fix mistral prompt assembly (#982)
* fix mistral prompts

* fix spacing

* remove elif
hamelsmu committed Dec 21, 2023
1 parent 161bcb6 commit 7bbaac9
Showing 2 changed files with 108 additions and 47 deletions.
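For reference, these are the assembled prompts that the new tests pin down (copied verbatim from the test assertions below); unlike llama-2, mistral uses no <<SYS>> wrapper and does not repeat `<s>` before follow-up instructions:

    # multi-turn with a system message, per the test assertions below
    llama2  = '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
    mistral = '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'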
24 changes: 23 additions & 1 deletion src/axolotl/monkeypatch/fastchat_conversation_turns.py
@@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements
             else:
                 yield role + ":", ""
             return
-        if self.sep_style == SeparatorStyle.LLAMA2:
+        if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
             if self.system_message:
                 if self.messages:
                     # For llama, the system message is incorporated into the first human instruction
@@ -101,6 +101,28 @@ def get_turns( # pylint: disable=too-many-return-statements
             else:
                 yield role, ""
             return
+        if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
+            contains_sys_msg = False
+            if self.system_message:
+                contains_sys_msg = True
+                if self.messages:
+                    # There is no clear guidance on how to handle system messages in Mistral, so we just prepend it to the first human instruction, separated by a newline
+                    first_role, first_msg = self.messages[0]
+                    if first_role == self.roles[0]:
+                        system_prompt = self.system_template.format(
+                            system_message=" " + self.system_message
+                        )
+                        system_prompt += first_msg
+                        self.messages.pop(0)
+                        yield "", system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message and i == 0 and not contains_sys_msg:
+                    yield "", system_prompt.strip() + " " + message  # if there is no system message, we need to make sure there is a `<s> [INST]` at the beginning of the first instruction
+                elif message:
+                    yield role + " ", message
+                else:
+                    yield role, ""
+            return
         if self.sep_style == SeparatorStyle.CHATGLM:
             # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
             # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
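Since the diff shows the new branch only in isolation, here is a minimal standalone sketch (not part of the commit) of what it yields. SYSTEM_TEMPLATE and ROLES are assumptions mirroring fastchat's "mistral" conversation template; the real method reads them, along with the upfront system_prompt initialisation, from the Conversation instance. Note that the branch consumes the first message via self.messages.pop(0), which is why the new tests below deepcopy shared fixtures before tokenizing.

    # Standalone sketch of the patched mistral branch; SYSTEM_TEMPLATE and
    # ROLES are assumptions, not part of this commit.
    SYSTEM_TEMPLATE = "[INST]{system_message}\n"
    ROLES = ("[INST]", "[/INST]")


    def mistral_turns(messages, system_message=""):
        """Yield (role, message) turns the way the patched mistral branch does."""
        messages = list(messages)  # avoid mutating the caller's list
        # get_turns() initialises system_prompt from the template up front;
        # the no-system-message branch below relies on that default
        system_prompt = SYSTEM_TEMPLATE.format(system_message=system_message)
        contains_sys_msg = False
        if system_message and messages:
            contains_sys_msg = True
            first_role, first_msg = messages[0]
            if first_role == ROLES[0]:
                # prepend the system message to the first human instruction,
                # separated by a newline
                system_prompt = SYSTEM_TEMPLATE.format(
                    system_message=" " + system_message
                )
                system_prompt += first_msg
                messages.pop(0)
                yield "", system_prompt
        for i, (role, message) in enumerate(messages):
            if message and i == 0 and not contains_sys_msg:
                # no system message: the first instruction still needs "[INST] "
                yield "", system_prompt.strip() + " " + message
            elif message:
                yield role + " ", message
            else:
                yield role, ""


    print(list(mistral_turns([("[INST]", "abc"), ("[/INST]", "ipsum")], "lorem")))
    # -> [('', '[INST] lorem\nabc'), ('[/INST] ', 'ipsum')]
    print(list(mistral_turns([("[INST]", "abc"), ("[/INST]", "ipsum")])))
    # -> [('', '[INST] abc'), ('[/INST] ', 'ipsum')]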
131 changes: 85 additions & 46 deletions tests/test_prompt_tokenizers.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import unittest
+from copy import deepcopy
 from pathlib import Path
 from typing import Optional

@@ -25,6 +26,50 @@

 LOG = logging.getLogger("axolotl")

+test_data = {
+    "multi_turn_sys": {
+        "conversations": [
+            {"from": "system", "value": "lorem"},
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+            {"from": "human", "value": "123"},
+            {"from": "gpt", "value": "sit"},
+        ]
+    },
+    "single_turn_sys": {
+        "conversations": [
+            {"from": "system", "value": "lorem"},
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+        ]
+    },
+    "single_turn_no_sys": {
+        "conversations": [
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+        ]
+    },
+    "multi_turn_no_sys": {
+        "conversations": [
+            {"from": "human", "value": "abc"},
+            {"from": "gpt", "value": "ipsum"},
+            {"from": "human", "value": "123"},
+            {"from": "gpt", "value": "sit"},
+        ]
+    },
+}
+
+
+def prompt_strat(conversation, tokenizer):
+    "Helper function to create a prompt strategy for testing."
+    prompter = ShareGPTPrompterV2(conversation=conversation)
+    return ShareGPTPromptTokenizingStrategy(
+        prompter,
+        tokenizer,
+        False,
+        2048,
+    )
+
+
 class TestPromptTokenizationStrategies(unittest.TestCase):
     """
@@ -116,74 +161,68 @@ def test_sharegpt_warnings_turns(self):

     def test_sharegpt_llama(self):
         "Make sure the sharegpt/llama is tokenized and formatted correctly."
-        prompter = ShareGPTPrompterV2(conversation="llama-2")
-        strat = ShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
+        strat = prompt_strat("llama-2", self.tokenizer)

         def tokenize(conv):
-            return strat.tokenize_prompt(conv)["input_ids"]
+            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]

         def decode(ids):
             return strat.tokenizer.decode(ids)

-        # Multi-turn conversations
-        multi_turn_conv = {
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "123"},
-                {"from": "gpt", "value": "sit"},
-            ]
-        }
         # fmt: off
-        mt_ids = tokenize(multi_turn_conv)
+        # System message, multi-turn conversations
+        mt_ids = tokenize(test_data['multi_turn_sys'])
         assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
         assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]

-        # Single-turn conversations
-        single_turn_conv = {
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-            ]
-        }
-
-        st_ids = tokenize(single_turn_conv)
+        # System message, single-turn conversations
+        st_ids = tokenize(test_data['single_turn_sys'])
         assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
         assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]

         # No system message, single-turn
-        no_sys_conv = {
-            "conversations": [
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-            ]
-        }
-
-        ns_ids = tokenize(no_sys_conv)
+        ns_ids = tokenize(test_data['single_turn_no_sys'])
         assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
         assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]

         # No system message, multi-turn
-        no_sys_mt_conv = {
-            "conversations": [
-                {"from": "human", "value": "abc"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "123"},
-                {"from": "gpt", "value": "sit"},
-            ]
-        }
-        ns_mt_ids = tokenize(no_sys_mt_conv)
+        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
         assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
         assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
         # fmt: on

+    def test_sharegpt_mistral(self):
+        "Make sure the sharegpt/mistral is tokenized and formatted correctly."
+        strat = prompt_strat("mistral", self.tokenizer)
+
+        def tokenize(conv):
+            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
+
+        def decode(ids):
+            return strat.tokenizer.decode(ids)
+
+        # fmt: off
+        # System message, multi-turn conversations
+        mt_ids = tokenize(test_data['multi_turn_sys'])
+        assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
+        assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
+
+        # System message, single-turn conversations
+        st_ids = tokenize(test_data['single_turn_sys'])
+        assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
+        assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
+
+        # No system message, single-turn
+        ns_ids = tokenize(test_data['single_turn_no_sys'])
+        assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
+        assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
+
+        # No system message, multi-turn
+        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
+        assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
+        assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
+        # fmt: on

     def test_sharegpt_changes_roles(self):
         conversation = {
             "roles": ["USER", "CHARACTER"],
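Assuming a local checkout with axolotl's dev dependencies installed, the updated and new cases can be run with, e.g., `pytest tests/test_prompt_tokenizers.py -k "sharegpt_llama or sharegpt_mistral"`.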
