final_truthfulqa_5.py

# To run this script from the CLI:
#   inspect eval final_truthfulqa_5.py --model hf/meta-llama/Meta-Llama-3-8B-Instruct
from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.scorer import scorer, accuracy, bootstrap_std
from inspect_ai.solver import multiple_choice, system_message, TaskState
from inspect_ai.scorer._target import Target
from inspect_ai.scorer._metric import Score

# Language code for the language you want to test:
# "am" = Amharic, "ha" = Hausa, "nso" = Northern Sotho, "sw" = Swahili, "yo" = Yoruba
lang_code = "yo"
few_shot = 5  # number of few-shot examples to prepend; change as needed
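
# Each @task function at the bottom of this file registers a per-language task. With
# the inspect CLI, a single task can be selected by name and task arguments overridden
# with -T (assumed invocation; adjust the model and task to what you have available,
# and use target=mc2 only if the dataset provides mc2_targets):
#   inspect eval final_truthfulqa_5.py@yoruba -T target=mc2 \
#     --model hf/meta-llama/Meta-Llama-3-8B-Instruct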


def sample_to_fewshot(sample):
    choices_text = "\n".join(
        [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(sample.choices)]
    )
    return f"Question: {sample.input}\n\nChoices:\n{choices_text}\n\nAnswer: {sample.target[0]}"


@scorer(metrics=[accuracy(), bootstrap_std()])
def logprob_based_scorer():
    async def score(state: TaskState, target: Target) -> Score:
        model_answer_full = state.output.completion.strip().upper()
        # Extract just the letter from the model's answer
        model_answer = model_answer_full.split(":")[-1].strip()
        # Handle the case where target.target is a list
        if isinstance(target.target, list):
            correct_answers = [ans.strip().upper() for ans in target.target]
        else:
            correct_answers = [str(target.target).strip().upper()]
        is_correct = model_answer in correct_answers

        logprobs = state.output.choices[0].logprobs
        if logprobs is None:
            return Score(
                value="C" if is_correct else "I",
                answer=model_answer,
                explanation=f"Model answer: {model_answer}, Correct answers: {correct_answers}, No logprobs available",
            )
        avg_logprob = sum(lp.logprob for lp in logprobs.content) / len(logprobs.content)
        return Score(
            value="C" if is_correct else "I",
            answer=model_answer,
            explanation=f"Model answer: {model_answer}, Correct answers: {correct_answers}, Average log probability: {avg_logprob:.4f}",
            metadata={"avg_logprob": avg_logprob},
        )

    return score


def labels_to_positions(labels: list[int]) -> list[str]:
    return [chr(ord("A") + i) for i, label in enumerate(labels) if label == 1]
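
# For illustration: labels_to_positions([1, 0, 0, 1]) -> ["A", "D"], i.e. the letters
# of the choices marked correct in the dataset's 0/1 label vector.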


def truthfulqa_task(lang_code, target="mc1", few_shot=5):
    def record_to_sample(record):
        return Sample(
            input=record["question"],
            choices=record[f"{target}_targets"]["choices"],
            target=labels_to_positions(record[f"{target}_targets"]["labels"]),
        )

    # Get few-shot examples from the train split
    fewshots = hf_dataset(
        path="ebayes/uhura-eval",
        name=f"{lang_code}_multiple_choice",
        sample_fields=record_to_sample,
        split="train",
        shuffle=True,
        seed=42,
        limit=few_shot,
    )
    fewshot_examples = "\n\n".join([sample_to_fewshot(sample) for sample in fewshots])

    # Evaluate on the test split
    dataset = hf_dataset(
        path="ebayes/uhura-eval",
        name=f"{lang_code}_multiple_choice",
        sample_fields=record_to_sample,
        split="test",
        shuffle=True,
    )

    # mc1 has exactly one correct answer; other targets allow several
    multiple_correct = target != "mc1"
    return Task(
        dataset=dataset,
        plan=[
            system_message(
                f"Here are some example questions and answers:\n\n{fewshot_examples}"
            ),
            multiple_choice(multiple_correct=multiple_correct, shuffle=True),
        ],
        scorer=logprob_based_scorer(),
    )


@task
def amharic(target="mc1"):
    return truthfulqa_task("am", target)


@task
def english(target="mc1"):
    return truthfulqa_task("en", target)


@task
def hausa(target="mc1"):
    return truthfulqa_task("ha", target)


@task
def sotho(target="mc1"):
    return truthfulqa_task("nso", target)


@task
def swahili(target="mc1"):
    return truthfulqa_task("sw", target)


@task
def yoruba(target="mc1"):
    return truthfulqa_task("yo", target)


@task
def zulu(target="mc1"):
    return truthfulqa_task("zu", target)