Skip to content

Commit

Permalink
add test for train_on_inputs for sharegpt
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 6, 2024
1 parent 3a5d71d commit 1669c29
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion tests/prompt_strategies/test_sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class TestSharegpt:
Test class for sharegpt prompter
"""

def test_something(self, sharegpt_dataset, tokenizer):
def test_no_double_im_end(self, sharegpt_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
Expand All @@ -94,3 +94,59 @@ def test_something(self, sharegpt_dataset, tokenizer):
32001, 13892, 13, 12684, 17664, 32000 # gpt
]
# fmt: on

def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation="chatml",
role_key_model=None,
role_key_human=None,
),
tokenizer,
True, # train_on_inputs
2048, # sequence_len
)

dataset_wrapper = TokenizedPromptDataset(
strategy, sharegpt_dataset, process_count=1
)

labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
-100, # bos
-100, -100, -100, -100, -100, # system
-100, -100, -100, -100, -100, # human
-100, -100, 13, 21558, 32000, # gpt
-100, -100, -100, -100, -100, -100, # human
-100, -100, 13, 12684, 17664, 32000 # gpt
]
# fmt: on

# def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
# strategy = SimpleShareGPTPromptTokenizingStrategy(
# ShareGPTPrompterV2(
# conversation="chatml",
# role_key_model=None,
# role_key_human=None,
# ),
# tokenizer,
# False, # train_on_inputs
# 2048, # sequence_len
# )
#
# dataset_wrapper = TokenizedPromptDataset(
# strategy, sharegpt_dataset, process_count=1
# )
#
# labels = dataset_wrapper[0]["labels"]
# # fmt: off
# assert labels == [
# 1, # bos
# 32001, 1587, 13, 25997, 32000, # system
# 32001, 2188, 13, 21558, 32000, # human
# 32001, 13892, 13, 21558, 32000, # gpt
# 32001, 2188, 13, 12684, 17664, 32000, # human
# 32001, 13892, 13, 12684, 17664, 32000 # gpt
# ]
# # fmt: on

0 comments on commit 1669c29

Please sign in to comment.