Falcon embeddings #1149

Merged · 6 commits · Jan 23, 2024
2 changes: 1 addition & 1 deletion examples/falcon/config-7b-lora.yml
@@ -60,5 +60,5 @@ fsdp:
 fsdp_config:
 special_tokens:
   pad_token: "<|endoftext|>"
-  bos_token: ">>ABSTRACT<<"
+  bos_token: "<|endoftext|>"
   eos_token: "<|endoftext|>"
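Note on the change above: Falcon's tokenizer uses `<|endoftext|>` as its end-of-text token, while `>>ABSTRACT<<` is one of the RefinedWeb control tokens and was never a BOS token, so the three example configs now reuse `<|endoftext|>` for bos/eos/pad. A quick, optional way to sanity-check this locally (assumes network access to the `tiiuae/falcon-7b` tokenizer; the commented values are expectations, not guarantees):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
print(tok.eos_token)                      # expected: <|endoftext|>
print(tok.bos_token)                      # expected: None unless set via special_tokens
print(">>ABSTRACT<<" in tok.get_vocab())  # expected: True, but only as a control token
```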
2 changes: 1 addition & 1 deletion examples/falcon/config-7b-qlora.yml
@@ -89,5 +89,5 @@ fsdp:
 fsdp_config:
 special_tokens:
   pad_token: "<|endoftext|>"
-  bos_token: ">>ABSTRACT<<"
+  bos_token: "<|endoftext|>"
   eos_token: "<|endoftext|>"
2 changes: 1 addition & 1 deletion examples/falcon/config-7b.yml
@@ -60,5 +60,5 @@ fsdp:
 fsdp_config:
 special_tokens:
   pad_token: "<|endoftext|>"
-  bos_token: ">>ABSTRACT<<"
+  bos_token: "<|endoftext|>"
   eos_token: "<|endoftext|>"
12 changes: 12 additions & 0 deletions src/axolotl/monkeypatch/falcon/__init__.py
@@ -0,0 +1,12 @@
"""
Patches to support multipack for falcon
"""
import transformers

from axolotl.monkeypatch.utils import get_unpad_data


def replace_falcon_attn_with_multipack_flash_attn():
    transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
        get_unpad_data
    )
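For readers unfamiliar with the multipack patches: the function above simply points transformers' Falcon module at axolotl's `get_unpad_data`, so flash attention unpads each packed sub-sequence separately instead of treating a packed row as one long sample. The sketch below is illustrative only, not the actual helper from `axolotl.monkeypatch.utils`; it assumes the packing collator writes per-sample ids (1, 2, 3, ...) into `attention_mask` and shows the return contract transformers expects from `_get_unpad_data`:

```python
import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask: torch.Tensor):
    """Illustrative multipack-aware replacement for transformers' _get_unpad_data."""
    device = attention_mask.device
    # lengths of every packed sub-sequence, row by row (ids are 1..k per row)
    seqlens_in_batch = torch.cat(
        [torch.bincount(row[row > 0])[1:] for row in attention_mask]
    ).to(dtype=torch.int32, device=device)
    # flat indices of all real (non-padding) tokens
    indices = torch.nonzero(attention_mask.flatten() > 0, as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # cumulative sequence lengths, zero-prefixed, as flash attention expects
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
```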
2 changes: 2 additions & 0 deletions src/axolotl/utils/lora_embeddings.py
@@ -11,4 +11,6 @@ def get_linear_embedding_layers(model_type):
         return ["embd.wte", "lm_head.linear"]
     if model_type == "gpt_neox":
         return ["embed_in", "embed_out"]
+    if model_type == "falcon":
+        return ["word_embeddings", "lm_head"]
     return ["embed_tokens", "lm_head"]
37 changes: 23 additions & 14 deletions src/axolotl/utils/models.py
@@ -334,6 +334,14 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()

+    if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
+        from axolotl.monkeypatch.falcon import (
+            replace_falcon_attn_with_multipack_flash_attn,
+        )
+
+        LOG.info("patching falcon with flash attention")
+        replace_falcon_attn_with_multipack_flash_attn()
+
     if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.qwen2 import (
             replace_qwen2_attn_with_multipack_flash_attn,
@@ -434,18 +442,13 @@ def load_model(
         if not cfg.sample_packing:
             if cfg.s2_attention:
                 pass
-            if (
-                cfg.is_llama_derived_model
-                or cfg.is_falcon_derived_model
-                or cfg.is_mistral_derived_model
-                or model_config.model_type in ["mixtral", "qwen2"]
-            ):
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_config._attn_implementation = (  # pylint: disable=protected-access
-                    "flash_attention_2"
-                )
+            # most other models support flash attention, we can define exceptions as they come up
+            model_kwargs["attn_implementation"] = "flash_attention_2"
+            model_config._attn_implementation = (  # pylint: disable=protected-access
+                "flash_attention_2"
+            )
         else:
-            if model_config.model_type in ["mixtral", "qwen2"]:
+            if model_config.model_type in ["mixtral", "qwen2", "falcon"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
@@ -461,7 +464,11 @@
         model_config.fused_dense = True

     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            model_config.model_type == "llama"
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+        ):
             from transformers import LlamaForCausalLM

             model = LlamaForCausalLM.from_pretrained(
@@ -755,8 +762,10 @@ def find_all_linear_names(model):
         names = name.split(".")
         lora_module_names.add(names[0] if len(names) == 1 else names[-1])

-    if "lm_head" in lora_module_names:  # needed for 16-bit
-        lora_module_names.remove("lm_head")
+    embedding_modules = get_linear_embedding_layers(model.config.model_type)
+    output_embedding = embedding_modules[1]
+    if output_embedding in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove(output_embedding)

     return list(lora_module_names)

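Taken together, the `load_model` changes above (a) apply the Falcon multipack patch when flash attention plus sample packing is enabled and (b) request `flash_attention_2` through the standard `attn_implementation` path. A condensed sketch of that flow, under the assumption that a CUDA GPU with flash-attn is available; `load_falcon_for_packing` is a hypothetical helper written for illustration, not an axolotl API:

```python
import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.falcon import replace_falcon_attn_with_multipack_flash_attn


def load_falcon_for_packing(base_model: str):
    # 1) patch transformers' Falcon module so unpadding is multipack-aware
    replace_falcon_attn_with_multipack_flash_attn()

    # 2) request flash attention 2 the same way load_model does
    return AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,  # flash attention needs fp16/bf16 weights
        attn_implementation="flash_attention_2",
    )
```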
6 changes: 6 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -124,6 +124,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")

+    if cfg.model_config_type == "falcon":
+        LOG.info("dropping token_type_ids column")
+        train_dataset = train_dataset.remove_columns("token_type_ids")
+        if eval_dataset:
+            eval_dataset = eval_dataset.remove_columns("token_type_ids")
+
     train_dataset = train_dataset.filter(
         drop_long,
         num_proc=cfg.dataset_processes,
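The trainer change above drops `token_type_ids` before packing, presumably because Falcon's tokenizer emits that column while the packed batches have no use for it. A tiny, self-contained illustration of the same column drop with the `datasets` API (the toy column contents are made up for the example):

```python
from datasets import Dataset

ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3]],
        "attention_mask": [[1, 1, 1]],
        "token_type_ids": [[0, 0, 0]],  # present when the tokenizer emits it
    }
)
if "token_type_ids" in ds.column_names:  # drop it only when it exists
    ds = ds.remove_columns("token_type_ids")
print(ds.column_names)  # ['input_ids', 'attention_mask']
```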
112 changes: 112 additions & 0 deletions tests/e2e/patched/test_falcon_samplepack.py
@@ -0,0 +1,112 @@
"""
E2E tests for falcon
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestFalconPatched(unittest.TestCase):
"""
Test case for Falcon models
"""

@with_temp_dir
def test_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"lora_modules_to_save": ["word_embeddings", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
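To run just the new e2e tests locally, something like the following should work (hedged: it assumes a CUDA GPU with flash-attn installed, since `flash_attention: True` is set in both test configs, and it is simply a Python wrapper around a plain pytest invocation):

```python
# Convenience runner equivalent to: pytest -q tests/e2e/patched/test_falcon_samplepack.py
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["-q", "tests/e2e/patched/test_falcon_samplepack.py"]))
```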
4 changes: 2 additions & 2 deletions tests/e2e/patched/test_mixtral_samplepack.py
@@ -32,6 +32,7 @@ def test_qlora(self, temp_dir):
                 "base_model": "hf-internal-testing/Mixtral-tiny",
                 "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
+                "sample_packing": True,
                 "sequence_len": 2048,
                 "load_in_4bit": True,
                 "adapter": "qlora",
@@ -57,7 +58,6 @@ def test_qlora(self, temp_dir):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
-                "sample_packing": True,
                 "bf16": "auto",
             }
         )
@@ -76,6 +76,7 @@ def test_ft(self, temp_dir):
                 "base_model": "hf-internal-testing/Mixtral-tiny",
                 "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                 "flash_attention": True,
+                "sample_packing": True,
                 "sequence_len": 2048,
                 "val_set_size": 0.1,
                 "special_tokens": {},
@@ -95,7 +96,6 @@ def test_ft(self, temp_dir):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
-                "sample_packing": True,
                 "bf16": "auto",
             }
         )