Skip to content

Commit

Permalink
chore: lint
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 8, 2024
1 parent 28e505b commit 5d8e178
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/axolotl/utils/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
axolotl samplers module
"""
from .multipack import MultipackBatchSampler # noqa: F401
from .utils import get_dataset_lengths
from .utils import get_dataset_lengths # noqa: F401
3 changes: 3 additions & 0 deletions src/axolotl/utils/samplers/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
helper util to calculate dataset lengths
"""
import numpy as np


Expand Down
34 changes: 34 additions & 0 deletions tests/test_prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,40 @@ def decode(ids):
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# fmt: on

def test_sharegpt_chatml(self):
    """Check the token ids and decoded text produced by the "chatml" strategy.

    Four conversations from ``test_data`` are tokenized: with and without a
    system message, each in single-turn and multi-turn form. For every one,
    both the decoded string and the raw ``input_ids`` are compared against
    hard-coded expectations.
    """

    # NOTE(review): nothing in this test adds an <|im_start|> token or sets
    # <|im_end|> as eos_token; presumably that has to happen inside
    # prompt_strat or on the tokenizer fixture - confirm.
    strat = prompt_strat("chatml", self.tokenizer)

    def tokenize(conv):
        # Work on a deep copy so the conversation in test_data is never
        # handed to the strategy directly.
        return strat.tokenize_prompt(deepcopy(conv))["input_ids"]

    def decode(ids):
        return strat.tokenizer.decode(ids)

    # NOTE(review): the expected strings below use the "[INST] ... [/INST]"
    # layout, and the last id list is identical to the one asserted just
    # before this test; confirm these really are the intended expectations
    # for the "chatml" strategy rather than values carried over.

    # fmt: off
    # System message, multi-turn conversations
    mt_ids = tokenize(test_data['multi_turn_sys'])
    assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
    assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]

    # System message, single-turn conversations
    st_ids = tokenize(test_data['single_turn_sys'])
    assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
    assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]

    # No system message, single-turn
    ns_ids = tokenize(test_data['single_turn_no_sys'])
    assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
    assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]

    # No system message, multi-turn
    ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
    assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
    assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
    # fmt: on

def test_sharegpt_changes_roles(self):
conversation = {
"roles": ["USER", "CHARACTER"],
Expand Down

0 comments on commit 5d8e178

Please sign in to comment.