-
Notifications
You must be signed in to change notification settings - Fork 4
/
test_openai.py
84 lines (77 loc) · 2.09 KB
/
test_openai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import pytest
from enrichment_models.llms.openai import ChatGPT
from enrichment_models.tasks.labeler.utils import label_similarity_score
def test_():
    """Check that gpt-3.5-turbo answers a simple factual question with a bare year.

    Sends a live request through ``ChatGPT.predict`` with ``temperature=0`` and
    expects the year "1789". Needs whatever network access / credentials
    ``ChatGPT`` requires (not visible here — presumably an OpenAI API key).
    """
    model = ChatGPT(model_name="gpt-3.5-turbo", temperature=0)
    res = model.predict(
        instructions="I have a question.",
        # Prompt fixed ("the date as, nothing else" was garbled) to match test_chatgpt.
        inputs="In which year the french revolution happened? (note: just return the date, nothing else)",
    )
    # Keep only the digit characters so stray punctuation or whitespace around
    # the year (e.g. a trailing "." or newline) does not fail the comparison.
    res = "".join(c for c in res if c.isdigit())
    assert res == "1789"
def test_chatgpt():
    """Ask gpt-4 for the year of the French Revolution and expect "1789".

    Issues a real ``ChatGPT.predict`` call at temperature 0. The reply is
    reduced to its digit characters before comparison, so punctuation or
    whitespace surrounding the year is tolerated.
    """
    question = "In which year the french revolution happened? (note: just return the date, nothing else)"
    chat = ChatGPT(model_name="gpt-4", temperature=0)
    reply = chat.predict(instructions="I have a question.", inputs=question)

    # Strip everything that is not a digit (e.g. a trailing dot) from the answer.
    digits_only = "".join(ch for ch in reply if ch.isdigit())
    assert digits_only == "1789"
def test_label_similarity_score():
    """Pin the per-pair scores produced by ``label_similarity_score``.

    Each row below is ``(prediction, ground_truth, expected_score)``. With
    ``average_reduction=False`` this test expects one score per pair, in row
    order, each within an absolute tolerance of 0.01.
    """
    cases = [
        ("App stores", "Software", 0.9),
        ("Buy now, pay later", "Books, newsletters, newspapers", 0.09),
        ("Food and Drink", "Liquor", 0.9),
        ("Clothing", "Recreational goods", 0.5),
        ("Cafes and coffee shops", "Food and Drink", 0.9),
        ("ATM/bank withdrawal", "Food and Drink", 0.14),
        ("Firearms", "Childcare", 0.05),
        ("Loan repayment", "Laundry", 0.12),
        ("Mortgage", "Mortgage", 1),
        ("Mortgage", "Insurance", 0.9),
        ("Auto lease payment", "Auto loan repayment", 0.9),
        ("Property rental", "Refunds", 0.04),
        ("Education", "Electronics", 0.34),
        ("Gambling", "Media", 0.40),
        ("Insurance", "Sport and fitness", 0.21),
    ]
    # Split the aligned table back into the parallel lists the scorer expects.
    preds = [pred for pred, _, _ in cases]
    ground_truths = [truth for _, truth, _ in cases]
    expected = [score for _, _, score in cases]

    scores = label_similarity_score(preds, ground_truths, average_reduction=False)
    assert scores == pytest.approx(expected, abs=0.01)