Skip to content

Commit

Permalink
First pass chat template (backend only)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxusmusti committed Jun 20, 2024
1 parent c6bffd4 commit 8141247
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 36 deletions.
23 changes: 23 additions & 0 deletions src/instructlab/training/chat_templates/ibm_generic_tmpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Special tokens and Jinja chat template for IBM generic (Granite-style) models."""

# Use the absolute package path: a bare `from tokenizer_utils import ...`
# only resolves if this directory happens to be on sys.path, and fails when
# this module is imported as
# `instructlab.training.chat_templates.ibm_generic_tmpl`
# (which is exactly how tokenizer_utils.py imports it).
from instructlab.training.tokenizer_utils import SpecialTokens

SPECIAL_TOKENS = SpecialTokens(
    system="<|system|>",
    user="<|user|>",
    assistant="<|assistant|>",
    eos="<|endoftext|>",
    pad="<|pad|>",
)

# Jinja template: each role token is followed by a newline and the message
# content. 'pretraining' messages are wrapped in EOS tokens; assistant turns
# end with EOS, and the final assistant turn omits the trailing newline.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'pretraining' %}"
    "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}"
    "{% elif message['role'] == 'system' %}"
    "{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
    "{% elif message['role'] == 'user' %}"
    "{{'<|user|>' + '\n' + message['content'] + '\n'}}"
    "{% elif message['role'] == 'assistant' %}"
    "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
    "{% endif %}"
    "{% endfor %}"
)
23 changes: 23 additions & 0 deletions src/instructlab/training/chat_templates/mistral_tmpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Special tokens and Jinja chat template for Mistral-family models."""

# Use the absolute package path: a bare `from tokenizer_utils import ...`
# only resolves if this directory happens to be on sys.path, and fails when
# this module is imported as
# `instructlab.training.chat_templates.mistral_tmpl`.
from instructlab.training.tokenizer_utils import SpecialTokens

# Mistral has no dedicated system/pad tokens; `pad` is left unset so
# setup_tokenizer() falls back to using `eos` for padding.
SPECIAL_TOKENS = SpecialTokens(
    bos="<s>",
    eos="</s>",
    user="[INST]",
    assistant="[/INST]",
)

# Jinja template following the Mistral instruct format: one leading BOS,
# user turns wrapped in [INST] ... [/INST], assistant turns terminated by
# EOS. 'pretraining' messages are emitted bare with a trailing EOS.
CHAT_TEMPLATE = (
    "{{ '<s>' }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'pretraining' %}"
    "{{ message['content'] + '</s>' }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'] + '</s>'}}"
    "{% endif %}"
    "{% endfor %}"
)
14 changes: 1 addition & 13 deletions src/instructlab/training/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,6 @@ def check_valid_sample(
if not any(token in whole_sentence_tk for token in special_tokens):
return True

# first token should be system_token
if whole_sentence_tk[0] != system_tk:
print("\033[91mfirst token is not a system_token\033[0m")
log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True)
return False

# check there's only one system_token
if (np.array(whole_sentence_tk) == system_tk).sum() != 1:
print("\033[91mthere are more than one system_token\033[0m")
log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True)
return False

whole_sentence_tk = np.array(whole_sentence_tk)
user_token_index = (whole_sentence_tk == user_tk).nonzero()[0]
assistant_token_index = (whole_sentence_tk == assistant_tk).nonzero()[0]
Expand Down Expand Up @@ -121,7 +109,7 @@ def unmask_only_assistant_responses(
whole_sentence = chosen_token["input_ids"][:sentence_legth].clone()

# pre-training mode
if system_tk not in whole_sentence:
if not (system_tk in whole_sentence or user_token in whole_sentence or assist_token in whole_sentence):
return labels

labels[:sentence_legth] = -100
Expand Down
40 changes: 17 additions & 23 deletions src/instructlab/training/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,38 @@

@dataclass
class SpecialTokens:
    """Bundle of special-token strings that a chat template relies on.

    `system` and `pad` default to None: a template may legitimately have no
    system token, and a missing pad token is backfilled with `eos` by
    setup_tokenizer().  Plain defaults replace the redundant
    `field(default=...)` wrappers (equivalent for immutable values).
    """

    system: str = None  # None => template has no system role token
    user: str = "<|user|>"
    assistant: str = "<|assistant|>"
    eos: str = "<|endoftext|>"
    pad: str = None  # None => fall back to eos as the pad token
    # NOTE(review): "begginingoftext" looks misspelled ("beginning") — kept
    # byte-identical because changing the literal changes the token string
    # added to the tokenizer; confirm against model vocabularies before fixing.
    bos: str = "<|begginingoftext|>"


SPECIAL_TOKENS = SpecialTokens()

CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'pretraining' %}"
"{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}"
"{% elif message['role'] == 'system' %}"
"{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
"{% elif message['role'] == 'user' %}"
"{{'<|user|>' + '\n' + message['content'] + '\n'}}"
"{% elif message['role'] == 'assistant' %}"
"{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
"{% endif %}"
"{% endfor %}"
)
#TODO: Replace with specified template path
from instructlab.training.chat_templates.ibm_generic_tmpl import SPECIAL_TOKENS, CHAT_TEMPLATE


def setup_tokenizer(
model_name_or_path, SPECIAL_TOKENS=SPECIAL_TOKENS, CHAT_TEMPLATE=CHAT_TEMPLATE
):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True)

if not SPECIAL_TOKENS.pad:
SPECIAL_TOKENS.pad = SPECIAL_TOKENS.eos
tokenizer.add_special_tokens(
{"eos_token": SPECIAL_TOKENS.eos, "pad_token": SPECIAL_TOKENS.pad}
{"bos_token": SPECIAL_TOKENS.bos, "eos_token": SPECIAL_TOKENS.eos, "pad_token": SPECIAL_TOKENS.pad}
)

if SPECIAL_TOKENS.system:
add_token_list = [SPECIAL_TOKENS.system]
else:
add_token_list = []
add_token_list.extend([SPECIAL_TOKENS.user, SPECIAL_TOKENS.assistant])

tokenizer.add_special_tokens(
{
"additional_special_tokens": [
SPECIAL_TOKENS.system,
SPECIAL_TOKENS.user,
SPECIAL_TOKENS.assistant,
]
"additional_special_tokens": add_token_list
}
)
if getattr(tokenizer, "add_bos_token", False) or getattr(
Expand Down

0 comments on commit 8141247

Please sign in to comment.