llama3 instruct and chat system prompts #950

Merged Jun 30, 2024 · 14 commits
Changes from 7 commits
72 changes: 72 additions & 0 deletions examples/evaluate_different_formats.py
@@ -0,0 +1,72 @@
from unitxt import get_logger
from unitxt.api import evaluate, load_dataset
from unitxt.blocks import TaskCard
from unitxt.collections_operators import Wrap
from unitxt.inference import IbmGenAiInferenceEngine, IbmGenAiInferenceEngineParams
from unitxt.loaders import LoadFromDictionary
from unitxt.templates import InputOutputTemplate
from unitxt.text_utils import print_dict

logger = get_logger()

# Set up question-answer pairs in a dictionary
data = {
    "train": [
        {"question": "How many days in a week", "answer": "7"},
        {"question": "Where is Spain?", "answer": "Europe"},
        {"question": "When was IBM founded?", "answer": "1911"},
        {"question": "Can pigs fly?", "answer": "No"},
    ],
    "test": [
        {"question": "What is the capital of Texas?", "answer": "Austin"},
        {"question": "What is the color of the sky?", "answer": "Blue"},
    ],
}

card = TaskCard(
    # Load the data from the dictionary. Data can also be loaded from HF, CSV files, COS,
    # and other sources with different loaders.
    loader=LoadFromDictionary(data=data),
    # Use the standard open QA task inputs, outputs, and metrics.
    # It has a "question" input field and an "answers" output field.
    # The default evaluation metric used is rouge.
    task="tasks.qa.open",
    # Because the standard QA task supports multiple references in the "answers" field,
    # we wrap the raw dataset's "answer" field in a list and store it in the "answers" field.
    preprocess_steps=[Wrap(field="answer", inside="list", to_field="answers")],
)

template = InputOutputTemplate(
    instruction="Answer the following questions in one word.",
    input_format="{question}",
    output_format="{answers}",
    postprocessors=["processors.lower_case"],
)

dataset = load_dataset(
    card=card,
    template=template,
    format="formats.llama3_instruct",
    system_prompt="system_prompts.models.llama2",
    num_demos=2,
    demos_pool_size=3,
)
test_dataset = dataset["test"]

model_name = "meta-llama/llama-3-70b-instruct"
gen_params = IbmGenAiInferenceEngineParams(max_new_tokens=32)
inference_model = IbmGenAiInferenceEngine(model_name=model_name, parameters=gen_params)

predictions = inference_model.infer(test_dataset)
evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)

print_dict(
    evaluated_dataset[0],
    keys_to_print=[
        "source",
        "prediction",
        "processed_prediction",
        "references",
        "score",
    ],
)
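A minimal sketch for sanity-checking the rendered prompt (assuming the test_dataset built above): each instance's "source" field holds the fully formatted input produced by formats.llama3_instruct, so printing it shows the Llama 3 special tokens, the system prompt, and the in-context demos in place.

# Inspect the rendered Llama 3 prompt for the first test instance.
print(test_dataset[0]["source"])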
55 changes: 44 additions & 11 deletions prepare/formats/models/llama3.py
@@ -2,25 +2,58 @@
from unitxt.formats import SystemFormat

# see: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
+# According to: https://huggingface.co/blog/llama3#how-to-prompt-llama-3
+# The Instruct versions use the following conversation structure:
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+#
# {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
-# {{ user_message }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+#
+# {{ user_msg_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+#
+# {{ model_answer_1 }}<|eot_id|>

format = SystemFormat(
-    demo_format="{source}\n\n{target_prefix}{target}\n\n",
-    model_input_format="<|begin_of_text|><|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
-    "{instruction}\\N{demos}{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-    "{target_prefix}",
+    demo_format="<|start_header_id|>user<|end_header_id|>\n\n"
+    "{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    "{target_prefix}{target}<|eot_id|>",
+    model_input_format="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+    + "{system_prompt}{instruction}"
+    + "<|eot_id|>{demos}<|start_header_id|>user<|end_header_id|>\n\n"
+    "{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}",
)
+
+add_to_catalog(
+    format,
+    "formats.llama3_instruct",
+    overwrite=True,
+)
+
+format = SystemFormat(
+    demo_format="{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    "{target_prefix}{target}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n",
+    model_input_format="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+    "{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruction}\n\n"
+    "{demos}"
+    "{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}",
+)

-add_to_catalog(format, "formats.llama3_chat", overwrite=True)
+add_to_catalog(
+    format,
+    "formats.llama3_instruct_alt1",
+    overwrite=True,
+)

format = SystemFormat(
    demo_format="{source}\n\n{target_prefix}{target}\n\n",
-    model_input_format="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
-    "{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
-    "{instruction}{demos}{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
-    "{target_prefix}",
+    model_input_format="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+    "{system_prompt}{instruction}"
+    "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
+    "{demos}"
+    "{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}",
)

-add_to_catalog(format, "formats.llama3_chat_with_system_prompt", overwrite=True)
+add_to_catalog(
+    format,
+    "formats.llama3_instruct_alt2",
+    overwrite=True,
+)
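For orientation, a sketch of what formats.llama3_instruct expands to with a single in-context demo, obtained by substituting one rendered demo_format block into the {demos} slot of model_input_format (the braced names are unitxt placeholders, filled per instance):

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system_prompt}{instruction}<|eot_id|><|start_header_id|>user<|end_header_id|>

{demo source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{target_prefix}{demo target}<|eot_id|><|start_header_id|>user<|end_header_id|>

{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{target_prefix}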
2 changes: 1 addition & 1 deletion prepare/metrics/llm_as_judge/llamaguard.py
@@ -9,7 +9,7 @@
"meta-llama/llama-3-8b-instruct",
"meta-llama/llama-3-70b-instruct",
] # will point to llamaguard2
format = "formats.llama3_chat"
format = "formats.llama3_instruct"
template = "templates.safety.unsafe_content"
task = "rating.single_turn"

@@ -6,7 +6,7 @@
from unitxt.llm_as_judge import LLMAsJudge

model_list = ["meta-llama/llama-3-8b-instruct", "meta-llama/llama-3-70b-instruct"]
format = "formats.llama3_chat"
format = "formats.llama3_instruct"
template = "templates.response_assessment.rating.mt_bench_single_turn"
task = "rating.single_turn"

11 changes: 11 additions & 0 deletions prepare/system_prompts/tasks/boolqa.py
@@ -0,0 +1,11 @@
from unitxt.catalog import add_to_catalog
from unitxt.system_prompts import TextualSystemPrompt

system_prompt = TextualSystemPrompt(
    "You are an agent in charge of answering a boolean (yes/no) question. The system presents "
    "you with a passage and a question. Read the passage carefully, and then answer yes or no. "
    "Think about your answer, and make sure it makes sense. Do not explain the answer. "
    "Only say yes or no."
)

add_to_catalog(system_prompt, "system_prompts.boolqa", overwrite=True)
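A minimal usage sketch (reusing the card and template objects from examples/evaluate_different_formats.py above): the new catalog entry is referenced by name in load_dataset, here together with the Llama 3 instruct format added in this PR.

dataset = load_dataset(
    card=card,  # from the example above; in practice a boolean-QA card is a better fit for this prompt
    template=template,
    format="formats.llama3_instruct",
    system_prompt="system_prompts.boolqa",  # the entry registered above
    num_demos=2,
    demos_pool_size=3,
)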
4 changes: 2 additions & 2 deletions src/unitxt/catalog/formats/llama3_chat.json
@@ -1,5 +1,5 @@
{
"__type__": "system_format",
"demo_format": "{source}\n\n{target_prefix}{target}\n\n",
"model_input_format": "<|begin_of_text|><|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruction}\\N{demos}{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}"
"demo_format": "<|start_header_id|>user<|end_header_id|>\n\n{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}{target}<|eot_id|>",
"model_input_format": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}{instruction}<|eot_id|>{demos}<|start_header_id|>user<|end_header_id|>\n\n{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}"
}
5 changes: 5 additions & 0 deletions src/unitxt/catalog/formats/llama3_instruct.json
@@ -0,0 +1,5 @@
{
"__type__": "system_format",
"demo_format": "<|start_header_id|>user<|end_header_id|>\n\n{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}{target}<|eot_id|>",
"model_input_format": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}{instruction}<|eot_id|>{demos}<|start_header_id|>user<|end_header_id|>\n\n{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}"
}
5 changes: 5 additions & 0 deletions src/unitxt/catalog/formats/llama3_instruct_alt1.json
@@ -0,0 +1,5 @@
{
"__type__": "system_format",
"demo_format": "{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}{target}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n",
"model_input_format": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruction}\n\n{demos}{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}"
}
5 changes: 5 additions & 0 deletions src/unitxt/catalog/formats/llama3_instruct_alt2.json
@@ -0,0 +1,5 @@
{
"__type__": "system_format",
"demo_format": "{source}\n\n{target_prefix}{target}\n\n",
"model_input_format": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}{instruction}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{demos}{source}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_prefix}"
}
@@ -10,6 +10,6 @@
},
"template": "templates.response_assessment.rating.mt_bench_single_turn",
"task": "rating.single_turn",
"format": "formats.llama3_chat",
"format": "formats.llama3_instruct",
"main_score": "llama_3_70b_instruct_ibm_genai_template_mt_bench_single_turn"
}
}
@@ -10,6 +10,6 @@
},
"template": "templates.response_assessment.rating.mt_bench_single_turn",
"task": "rating.single_turn",
"format": "formats.llama3_chat",
"format": "formats.llama3_instruct",
"main_score": "llama_3_8b_instruct_ibm_genai_template_mt_bench_single_turn"
}
}
@@ -10,6 +10,6 @@
},
"template": "templates.safety.unsafe_content",
"task": "rating.single_turn",
"format": "formats.llama3_chat",
"format": "formats.llama3_instruct",
"main_score": "llama_3_70b_instruct_ibm_genai_template_unsafe_content"
}
}
@@ -10,6 +10,6 @@
},
"template": "templates.safety.unsafe_content",
"task": "rating.single_turn",
"format": "formats.llama3_chat",
"format": "formats.llama3_instruct",
"main_score": "llama_3_8b_instruct_ibm_genai_template_unsafe_content"
}
}
4 changes: 4 additions & 0 deletions src/unitxt/catalog/system_prompts/boolqa.json
@@ -0,0 +1,4 @@
{
"__type__": "textual_system_prompt",
"text": "You are an agent in charge of answering a boolean (yes/no) question. The system presents you with a passage and a question. Read the passage carefully, and then answer yes or no. Think about your answer, and make sure it makes sense. Do not explain the answer. Only say yes or no. "
}
2 changes: 1 addition & 1 deletion tests/library/test_metrics.py
@@ -1452,7 +1452,7 @@ def _test_grouped_instance_confidence_interval(

def test_llm_as_judge_metric(self):
model_id = "meta-llama/llama-3-8b-instruct"
format = "formats.llama3_chat"
format = "formats.llama3_instruct"
task = "rating.single_turn"
template = "templates.response_assessment.rating.mt_bench_single_turn"
