Implement fused modules #747

Merged · 23 commits · Oct 21, 2023
2 changes: 2 additions & 0 deletions README.md
@@ -684,6 +684,8 @@ xformers_attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
10 changes: 7 additions & 3 deletions examples/llama-2/README.md
@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
micro_batch_size: 1

```shell
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml

accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
```
or

```shell
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
```

To launch a full finetuning with 16-bit precision:

```shell
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
```
73 changes: 73 additions & 0 deletions examples/llama-2/fft_optimized.yml
@@ -0,0 +1,73 @@
base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_steps: 100
eval_steps: 0.05
eval_table_size:
save_steps:
debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
128 changes: 123 additions & 5 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -13,12 +13,18 @@
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import (
LlamaMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name

try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
@@ -38,6 +44,28 @@
LOG = logging.getLogger("axolotl")


def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
module.config, module.gate_proj, module.up_proj, module.down_proj
)
set_module_name(model, name, mlp)


def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
qkv = FusedAttention(
module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.o_proj,
)
set_module_name(model, name, qkv)


def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
@@ -86,6 +114,91 @@ def __init__(self, hidden_size, eps=1e-6):
)


class FusedAttention(LlamaAttention):
"""
Fused QKV Attention layer for incrementally improved training efficiency
"""

def __init__(
self,
config,
q: torch.nn.Linear, # pylint: disable=invalid-name
k: torch.nn.Linear, # pylint: disable=invalid-name
v: torch.nn.Linear, # pylint: disable=invalid-name
o: torch.nn.Linear, # pylint: disable=invalid-name
):
super().__init__(config)
self.config = config
self.init_device = next(iter(q.state_dict().values())).device

# define equivalent fused qkv projection
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
self.qkv_proj = torch.nn.Linear(
q.in_features, sum(self.out_features), device=self.init_device, bias=False
)
self.o_proj = o

# overwrite initialized weights with pretrained weights
self.qkv_proj.weight.data = torch.cat(
(q.weight.data, k.weight.data, v.weight.data), dim=0
)

def _post_training(self, model, name):
q_proj, k_proj, v_proj = torch.split(
self.qkv_proj.weight.data, self.out_features, dim=0
)

new_attn = LlamaAttention(self.config)
new_attn.q_proj.weight.data = q_proj
new_attn.k_proj.weight.data = k_proj
new_attn.v_proj.weight.data = v_proj

set_module_name(model, name, new_attn)


class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""

def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data

def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)

# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data

set_module_name(model, name, new_mlp)

def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
@@ -147,9 +260,14 @@ def flashattn_forward(
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if isinstance(self, FusedAttention):
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
self.out_features, dim=-1
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
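
A quick way to see why `FusedAttention` is safe is that stacking the `q_proj`, `k_proj`, and `v_proj` weight matrices row-wise gives one linear layer whose output, split back by `out_features`, matches the three separate projections. A minimal sketch of that equivalence, using made-up toy dimensions rather than anything from a real checkpoint:

```python
import torch

# Toy sizes standing in for hidden_size and the per-projection output widths.
hidden, q_out, kv_out = 32, 32, 16
q = torch.nn.Linear(hidden, q_out, bias=False)
k = torch.nn.Linear(hidden, kv_out, bias=False)
v = torch.nn.Linear(hidden, kv_out, bias=False)

# Build one fused projection whose weight is the row-wise concat of the three
# weights, mirroring FusedAttention.__init__ above.
out_features = [q_out, kv_out, kv_out]
qkv = torch.nn.Linear(hidden, sum(out_features), bias=False)
qkv.weight.data = torch.cat((q.weight.data, k.weight.data, v.weight.data), dim=0)

x = torch.randn(2, 5, hidden)
# Same split that flashattn_forward performs for FusedAttention.
q_s, k_s, v_s = qkv(x).split(out_features, dim=-1)

assert torch.allclose(q_s, q(x), atol=1e-6)
assert torch.allclose(k_s, k(x), atol=1e-6)
assert torch.allclose(v_s, v(x), atol=1e-6)
```

The single GEMM is what buys the incremental speedup; the split afterwards only carves up the output, so the math is unchanged.
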
13 changes: 13 additions & 0 deletions src/axolotl/monkeypatch/utils.py
@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
max_seq_lens.append(max_seq_len)

return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)


def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name

setattr(parent, child_name, value)
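
`set_module_name` is the small utility both replacement helpers rely on: it resolves the parent of a dotted submodule name and swaps the child in place. A short illustrative use on a hypothetical toy module (only `set_module_name` itself comes from this PR):

```python
import torch

from axolotl.monkeypatch.utils import set_module_name

# Hypothetical nested model; named_modules() yields dotted names like "block.1".
model = torch.nn.Module()
model.block = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

for name, module in list(model.named_modules()):
    if isinstance(module, torch.nn.ReLU):
        # Replaces model.block[1] in place, the same pattern used by
        # replace_llama_mlp_with_swiglu / replace_llama_qkv_with_fused.
        set_module_name(model, name, torch.nn.GELU())

assert isinstance(model.block[1], torch.nn.GELU)
```
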
10 changes: 6 additions & 4 deletions src/axolotl/train.py
@@ -39,10 +39,7 @@ class TrainDatasetMeta:


def train(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
@@ -119,6 +116,11 @@ def terminate_handler(_, __, model):

LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

# post training
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access

if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
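
The post-training loop above relies on each fused module exposing a `_post_training` hook that splits its weights back into stock Llama layers, so the saved checkpoint keeps the original layout. A rough sketch of that round trip for `FusedAttention`, assuming flash-attn and xformers are installed (the patch module imports them at the top) and the transformers version this PR targets; the tiny config sizes are made up:

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from axolotl.monkeypatch.llama_attn_hijack_flash import FusedAttention

# Deliberately tiny config so the example is cheap to run; not a real model.
config = LlamaConfig(hidden_size=64, num_attention_heads=4)
attn = LlamaAttention(config)

# Wrap the existing projections in a fused module, as replace_llama_qkv_with_fused does.
model = torch.nn.ModuleDict(
    {"attn": FusedAttention(config, attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj)}
)

# Same pattern as the post-training loop added to train.py.
for name, module in list(model.named_modules()):
    if hasattr(module, "_post_training"):
        module._post_training(model, name)

assert type(model["attn"]) is LlamaAttention  # no longer the FusedAttention subclass
# The q/k/v weights come back out of the fused projection unchanged.
assert torch.equal(model["attn"].q_proj.weight, attn.q_proj.weight)
```
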
9 changes: 9 additions & 0 deletions src/axolotl/utils/config.py
@@ -189,9 +189,15 @@ def validate_config(cfg):
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")

if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with QLoRA")

if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
raise ValueError("Fused modules are not supported with LoRA")

if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
@@ -205,6 +211,9 @@ def validate_config(cfg):
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")

if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")

if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
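
On the validation side, the new checks surface the incompatibility with adapters and ReLoRA up front instead of failing deep inside the patch. An untested, hypothetical sketch of how that might look, assuming `validate_config` can be called directly on a minimal `DictDefault` with only the keys shown:

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# Hypothetical minimal config: a LoRA adapter combined with fused QKV should be rejected.
cfg = DictDefault(
    {"adapter": "lora", "load_in_8bit": True, "flash_attn_fuse_qkv": True}
)

try:
    validate_config(cfg)
except ValueError as err:
    print(err)  # expected: "Fused modules are not supported with LoRA"
```
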
14 changes: 14 additions & 0 deletions src/axolotl/utils/models.py
@@ -272,6 +272,20 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)

if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)

if cfg.flash_attn_fuse_mlp:
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model)

if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs