Skip to content

Commit

Permalink
Implement fused modules (#747)
Browse files Browse the repository at this point in the history
* MLP: Memory saving

* Remove RMSNorm restrictions

* Map packed weights to original

* FusedAttention module

* Simplify code

* Move fused modules

* Fix critical typo

* Split inplace

* Add FFT config

* Add validation of fused arguments

* Add fused arguments to config

* Update docs

* Fix validation logic

* Add fused modules to flash attn

* Only fuse during training

* Remove timing

* Formatting

* Formatting

* Formatting

* chore: lint

* chore: lint

* add e2e tests for fused llama

* no lora for tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
casper-hansen and winglian committed Oct 21, 2023
1 parent a21935f commit 15d3a65
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 12 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,8 @@ xformers_attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
Expand Down
10 changes: 7 additions & 3 deletions examples/llama-2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
micro_batch_size: 1

```shell
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml

accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
```
or

```shell
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
```

To launch a full finetuning with 16-bit precision:

```shell
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
```
73 changes: 73 additions & 0 deletions examples/llama-2/fft_optimized.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_steps: 100
eval_steps: 0.05
eval_table_size:
save_steps:
debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
Empty file.
128 changes: 123 additions & 5 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import (
LlamaMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name

try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
Expand All @@ -38,6 +44,28 @@
LOG = logging.getLogger("axolotl")


def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
module.config, module.gate_proj, module.up_proj, module.down_proj
)
set_module_name(model, name, mlp)


def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
qkv = FusedAttention(
module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.o_proj,
)
set_module_name(model, name, qkv)


def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
Expand Down Expand Up @@ -86,6 +114,91 @@ def __init__(self, hidden_size, eps=1e-6):
)


class FusedAttention(LlamaAttention):
"""
Fused QKV Attention layer for incrementally improved training efficiency
"""

def __init__(
self,
config,
q: torch.nn.Linear, # pylint: disable=invalid-name
k: torch.nn.Linear, # pylint: disable=invalid-name
v: torch.nn.Linear, # pylint: disable=invalid-name
o: torch.nn.Linear, # pylint: disable=invalid-name
):
super().__init__(config)
self.config = config
self.init_device = next(iter(q.state_dict().values())).device

# define equivalent fused qkv projection
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
self.qkv_proj = torch.nn.Linear(
q.in_features, sum(self.out_features), device=self.init_device, bias=False
)
self.o_proj = o

# overwrite initialized weights with pretrained weights
self.qkv_proj.weight.data = torch.cat(
(q.weight.data, k.weight.data, v.weight.data), dim=0
)

def _post_training(self, model, name):
q_proj, k_proj, v_proj = torch.split(
self.qkv_proj.weight.data, self.out_features, dim=0
)

new_attn = LlamaAttention(self.config)
new_attn.q_proj.weight.data = q_proj
new_attn.k_proj.weight.data = k_proj
new_attn.v_proj.weight.data = v_proj

set_module_name(model, name, new_attn)


class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""

def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data

def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)

# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data

set_module_name(model, name, new_mlp)

def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
Expand Down Expand Up @@ -147,9 +260,14 @@ def flashattn_forward(
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if isinstance(self, FusedAttention):
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
self.out_features, dim=-1
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
Expand Down
13 changes: 13 additions & 0 deletions src/axolotl/monkeypatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
max_seq_lens.append(max_seq_len)

return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)


def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name

setattr(parent, child_name, value)
10 changes: 6 additions & 4 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ class TrainDatasetMeta:


def train(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
Expand Down Expand Up @@ -120,6 +117,11 @@ def terminate_handler(_, __, model):

LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

# post training
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access

if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
Expand Down
9 changes: 9 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,15 @@ def validate_config(cfg):
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")

if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with QLoRA")

if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
raise ValueError("Fused modules are not supported with LoRA")

if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
Expand All @@ -205,6 +211,9 @@ def validate_config(cfg):
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")

if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")

if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
Expand Down
14 changes: 14 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,20 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)

if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)

if cfg.flash_attn_fuse_mlp:
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model)

if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs
Expand Down
Loading

0 comments on commit 15d3a65

Please sign in to comment.