Add a chat_template prompt strategy for DPO #1725

Merged · 6 commits · Jul 21, 2024
81 changes: 81 additions & 0 deletions examples/llama-3/instruct-dpo-lora-8b.yml
@@ -0,0 +1,81 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
chat_template: llama3
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant

dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out

sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
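
For reference, a row of the DPO dataset under this config is expected to carry the fields named above. A minimal sketch of one such row, assuming the schema the config keys imply (the actual contents of fozziethebeat/alpaca_messages_2k_dpo_test may differ):

# Hypothetical row, shaped by the config keys above:
# field_messages="conversation", field_chosen="chosen", field_rejected="rejected".
row = {
    "conversation": [
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "chosen": {"role": "assistant", "content": "Paris."},
    "rejected": {"role": "assistant", "content": "London."},
}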
78 changes: 78 additions & 0 deletions src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -0,0 +1,78 @@
"""
DPO prompt strategies for using tokenizer chat templates.
"""

from axolotl.utils.chat_templates import chat_templates


def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_str = chat_templates(cfg.chat_template)

field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
field_message_role = ds_cfg.get("message_field_role", "role")
field_message_content = ds_cfg.get("message_field_content", "content")
role_map_inv = ds_cfg.get(
"roles",
{
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
)
role_map = {}
for target, sources in role_map_inv.items():
for source in sources:
role_map[source] = target

def transform_fn(sample, tokenizer=None):
messages = sample[field_messages]
messages = [
{
"role": role_map[m[field_message_role]],
"content": m[field_message_content],
}
for m in messages
]
chosen = {
"role": role_map[sample[field_chosen][field_message_role]],
"content": sample[field_chosen][field_message_content],
}
rejected = {
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}

result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=False,
)

result["chosen"] = tokenizer.apply_chat_template(
[chosen],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:]

result["rejected"] = tokenizer.apply_chat_template(
[rejected],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])
result["rejected"] = result["rejected"][rejected_strip_index:]

return result

return transform_fn
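
As a usage sketch, the strategy can be wired up by hand roughly like this (mirroring the tests below; the tokenizer checkpoint is the one the tests use, and the sample row is illustrative):

from transformers import AutoTokenizer

from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {"chat_template": "llama3", "datasets": [{"chat_template": "llama3"}]}
)
transform_fn = default(cfg)

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"

sample = {
    "messages": [{"role": "user", "content": "hello"}],
    "chosen": {"role": "assistant", "content": "hi"},
    "rejected": {"role": "assistant", "content": "bye"},
}
row = transform_fn(sample, tokenizer=tokenizer)
# row["prompt"] ends with the assistant generation header;
# the find()-based stripping reduces row["chosen"] / row["rejected"]
# to just the completion text plus the trailing <|eot_id|>.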
1 change: 1 addition & 0 deletions src/axolotl/utils/data/rl.py
@@ -1,4 +1,5 @@
"""data handling specific to DPO"""

import inspect
import logging
from functools import partial
2 changes: 1 addition & 1 deletion src/axolotl/utils/tokenization.py
@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens."""
colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
-    for token in tokenizer.encode(tokens)
+    for token in tokenizer.encode(tokens, add_special_tokens=False)
Comment from the PR author:
Note: I added this because, by default, this step was including the BOS token every time. Since the BOS token is already included, it seemed reasonable not to add it a second time.

]
return colored_tokens

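For context, a minimal reproduction of the double-BOS behavior the author describes (assuming a llama-3 style tokenizer, as in the tests below; the exact token ids are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

# The text handed to this helper already comes out of apply_chat_template,
# so it already starts with the BOS token.
text = "<|begin_of_text|>hello"
print(tokenizer.encode(text))                            # BOS appears twice: once from the text, once added by encode()
print(tokenizer.encode(text, add_special_tokens=False))  # BOS appears once, from the text itself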
156 changes: 156 additions & 0 deletions tests/prompt_strategies/test_dpo_chat_templates.py
@@ -0,0 +1,156 @@
"""
tests for chat_template prompt strategy
"""

import unittest

import pytest
from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault


@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "hello",
},
{
"role": "user",
"content": "goodbye",
},
],
"chosen": {
"role": "assistant",
"content": "goodbye",
},
"rejected": {
"role": "assistant",
"content": "party on",
},
}
]
)


@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"conversation": [
{
"speaker": "human",
"text": "hello",
},
{
"speaker": "agent",
"text": "hello",
},
{
"speaker": "human",
"text": "goodbye",
},
],
"better": {
"speaker": "agent",
"text": "goodbye",
},
"worse": {
"speaker": "agent",
"text": "party on",
},
}
]
)


@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"

return tokenizer


class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
"""

def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"

def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "better",
"field_rejected": "worse",
"message_field_role": "speaker",
"message_field_content": "text",
"roles": {
"user": ["human"],
"assistant": ["agent"],
"system": ["sys"],
},
}
],
}
)
)
result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"


if __name__ == "__main__":
unittest.main()