Skip to content

Commit

Permalink
Feat: Add Qwen (axolotl-ai-cloud#894)
Browse files Browse the repository at this point in the history
* Feat: Add Qwen

* feat: add qwen lora example

* feat: update matrix

* fix: add trust_remote_code

* fix: disable gradient checkpointing

* chore: add warning about gradient checkpointing

* fix: config

* fix: turn off sample packing for this example and reduce seq len

* chore: add comment on seq len
  • Loading branch information
NanoCode012 committed Nov 25, 2023
1 parent d14be20 commit 726371f
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Features:
| XGen ||||||||
| phi ||||||||
| RWKV ||||||||
| Qwen ||||||||


## Quickstart ⚡
Expand Down Expand Up @@ -499,6 +500,7 @@ is_falcon_derived_model:
is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:
is_qwen_derived_model:

# optional overrides to the base model configuration
model_config:
Expand Down
68 changes: 68 additions & 0 deletions examples/qwen/lora.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

is_qwen_derived_model: true
trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out

sequence_len: 2048 # supports up to 8192
sample_packing: false
pad_to_sequence_len:

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
68 changes: 68 additions & 0 deletions examples/qwen/qlora.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

is_qwen_derived_model: true
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out

sequence_len: 2048 # supports up to 8192
sample_packing: false
pad_to_sequence_len:

adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
18 changes: 18 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,19 @@ def normalize_config(cfg):
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)

cfg.is_qwen_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"qwen",
]
)
or cfg.is_qwen_derived_model
or "qwen" in cfg.base_model.lower()
or (cfg.model_type and "qwen" in cfg.model_type.lower())
)

if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)

Expand Down Expand Up @@ -379,6 +392,11 @@ def validate_config(cfg):
if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")

if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
LOG.warning(
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
Expand Down
12 changes: 12 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ def load_tokenizer(cfg):
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"

# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, tokenizer.eod_id)

token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_names:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")

if cfg.special_tokens:
for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens(
Expand Down

0 comments on commit 726371f

Please sign in to comment.