-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_datasets.py
135 lines (92 loc) · 3.41 KB
/
test_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from datasets import *
def test_next_rationale_step_is_correct_for_zeroed_inputs():
    """Advance a step-0 state with ``x == y == 0`` and a carry of 1.

    The successor returned by ``next_rational_step`` is expected to have
    ``acc == "1"``, ``carry == 0`` and ``step == 1``, with ``x``, ``y`` and
    ``n_steps`` unchanged.
    """
    # Given: an empty accumulator with a carry still to be written out
    initial = RationaleStep(x=0, y=0, acc="", carry=1, step=0, n_steps=2)
    wanted = RationaleStep(x=0, y=0, acc="1", carry=0, step=1, n_steps=2)

    # When: the state is advanced once
    produced = next_rational_step(initial)

    # Then: the result equals the hand-written successor
    assert produced == wanted
def test_next_rationale_step_is_correct_penultimate_step():
    """Advance a state from step 2 to step 3 of a 4-step rationale.

    With ``x == y == 0`` and no carry, the accumulator ``"86"`` is expected
    to become ``"086"`` while the carry stays 0.
    """
    # Given: fields that are the same before and after the transition
    unchanged = {"x": 0, "y": 0, "carry": 0, "n_steps": 4}
    before = RationaleStep(acc="86", step=2, **unchanged)
    after = RationaleStep(acc="086", step=3, **unchanged)

    # When: the state is advanced once
    outcome = next_rational_step(before)

    # Then: the result equals the hand-written successor
    assert outcome == after
def test_rationale_to_str_is_correct_for_zeroed_inputs():
    """Render a two-step rationale (steps 2 and 3 of 4) as text.

    The expected rendering puts each step on its own line: the first step
    as ``", 8 6 C: 0"`` and the second as ``"0 8 6"``.
    """
    # Given: two consecutive steps with x == y == 0 and no carry
    step_two = RationaleStep(x=0, y=0, acc="86", carry=0, step=2, n_steps=4)
    step_three = RationaleStep(x=0, y=0, acc="086", carry=0, step=3, n_steps=4)
    steps = [step_two, step_three]
    wanted_text = ", 8 6 C: 0\n0 8 6"

    # When: the whole rationale is rendered
    rendered = rationale_to_str(steps)

    # Then: the text matches exactly
    assert rendered == wanted_text
def test_corrected_answer_gives_correct_answer_for_many_inputs():
    """Check ``correction_answer`` against ``x + y`` on many corrupted examples.

    Ten examples are generated for every ``default_n_digits`` value from 1
    to 29; each example's ``correction_answer`` must equal the sum of its
    question's ``x`` and ``y``.
    """
    # Ten trials per digit setting, settings visited in ascending order.
    digit_settings = (digits for digits in range(1, 30) for _ in range(10))

    for n_digits in digit_settings:
        # Given: a freshly generated corrupted example
        ce = make_corrupted_example(default_n_digits=n_digits)
        true_sum = ce.question.x + ce.question.y

        # When: the stored correction is read back
        reported = ce.correction_answer

        # Then: it is the real sum of the operands
        assert reported == true_sum, f"Failed for example\n{corrupted_example_to_str(ce)}"
def test_rationale_step_to_str_on_penultimate_step_with_zero_carry():
    """Render step 2 of 4 with accumulator ``"29"`` and carry 0.

    The expected text is ``", 2 9 C: 0"``.
    """
    # Given: a step with a zero carry
    state = RationaleStep(x=0, y=0, acc="29", carry=0, step=2, n_steps=4)
    wanted_text = ", 2 9 C: 0"

    # When: the single step is rendered
    rendered = rationale_step_to_str(state)

    # Then: the carry is reported as 0
    assert rendered == wanted_text
def test_rationale_step_to_str_on_penultimate_step_with_one_carry():
    """Render step 2 of 4 with accumulator ``"29"`` and carry 1.

    The expected text is ``", 2 9 C: 1"``.
    """
    # Given: a step with a carry of one
    state = RationaleStep(x=0, y=0, acc="29", carry=1, step=2, n_steps=4)
    wanted_text = ", 2 9 C: 1"

    # When: the single step is rendered
    rendered = rationale_step_to_str(state)

    # Then: the carry is reported as 1
    assert rendered == wanted_text
def test_rationale_step_to_str_on_final_step():
    """Render step 4 of 5 with accumulator ``"086"`` and carry 0.

    The expected text is just the space-separated digits, ``"0 8 6"``.
    """
    # Given: a state at step 4 of 5
    state = RationaleStep(x=0, y=0, acc="086", carry=0, step=4, n_steps=5)
    wanted_text = "0 8 6"

    # When: the single step is rendered
    rendered = rationale_step_to_str(state)

    # Then: only the digits appear in the output
    assert rendered == wanted_text
def test_rationale_is_corrupt_when_is_corrupt_is_set_to_true():
    """Check that rationales built with ``is_corrupted=True`` get a correction.

    For 100 questions made with ``n_digits=1``, ``correct_rationale`` must
    return something other than ``None`` for the generated rationale.
    """
    n_trials = 100
    for _trial in range(n_trials):
        # Given: a one-digit question and a rationale requested as corrupted
        q = make_question(n_digits=1)
        r = make_rationale(q, is_corrupted=True)

        # When: a correction is computed for that rationale
        correction = correct_rationale(q, r)

        # Then: a correction exists
        assert correction is not None, f"Question {q} had a generated rationale {r}"
def test_few_shot_dataset_is_deterministic():
    """Check that two identical few-shot generation calls give equal results."""
    # Given: one set of generation arguments used for both calls
    settings = {"n_examples": 10, "min_n_digits": 1, "max_n_digits": 10}

    # When: the examples are generated twice, one call after the other
    runs = [generate_few_shot_examples(**settings) for _ in range(2)]
    earlier, later = runs

    # Then: both calls produced equal examples
    assert earlier == later
def test_arithmetic_dataset_is_deterministic():
    """Check that two identically configured datasets yield equal items."""
    # Given: one set of constructor arguments used for both datasets
    settings = {"n_examples": 1_000, "min_n_digits": 1, "max_n_digits": 10}

    # When: each dataset is built and fully consumed, one after the other
    runs = [list(ArithmeticDataset(**settings)) for _ in range(2)]
    earlier, later = runs

    # Then: both materialised lists are equal
    assert earlier == later
def test_labelled_arithmetic_dataset_can_produce_corrupted_correction():
    """Check that a corrupted completion of ``0 + 4`` does not end at 4.

    Starting from step 0 with an empty accumulator, ``complete_rationale``
    is called with ``is_corrupted=True``; the last step's accumulator, read
    as an integer, must differ from 4.
    """
    # Given: the initial state for x=0, y=4 over three steps
    start = RationaleStep(x=0, y=4, acc='', carry=0, step=0, n_steps=3)

    # When: the rationale is completed in corrupted mode
    rationale = complete_rationale(start, is_corrupted=True)

    # Then: the final accumulator is not the true sum
    last_step = rationale[-1]
    assert int(last_step.acc) != 4