-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference.py
106 lines (90 loc) · 2.88 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os, pickle
import sys
sys.path.append("./model/")
sys.path.append("./preprocess/")
from transformers import Trainer, TrainingArguments
import numpy as np
import evaluate
from preprocess.roundataset import roundataset
from model.modeling_deberta_visual import DebertaForPhotobookListener
from model.variables import (
EPOCHS, CKPT_DIR,
BATCH_SIZE, PEAK_LR, WARMUP_STEPS, WEIGHT_DECAY,
)
# Accuracy metric reused by compute_metrics on every evaluation pass.
metric = evaluate.load("accuracy")

# Checkpoint directory to evaluate — the script's single CLI argument.
ckpt_dir = sys.argv[1]

# Per-checkpoint prediction dumps are written here.
res_dir = "results/"
# exist_ok=True replaces the race-prone exists()-then-makedirs check.
os.makedirs(res_dir, exist_ok=True)
def compute_metrics(eval_pairs):
    """Compute accuracy on the last labeled timestep of each sequence.

    Args:
        eval_pairs: ``(predictions, labels)`` pair from the HF Trainer.
            ``predictions`` is argmax-ed over its last axis and indexed as
            ``[b, pos]``, so it is assumed to be (batch, seq, num_classes)
            logits; ``labels`` is indexed as ``[b, pos, 0]``, so it is
            assumed to be (batch, seq, k) with -100 marking masked
            positions — TODO confirm shapes against the dataset/collator.

    Returns:
        The dict produced by ``metric.compute`` (an "accuracy" entry).

    Side effects:
        Dumps ``{"labels": ..., "preds": ...}`` to
        ``results/<basename(ckpt_dir)>.pkl`` for offline analysis.
    """
    predictions, labels = eval_pairs
    # Class 0 is excluded from the argmax; +1 restores original class ids.
    predictions = np.argmax(predictions[..., 1:], axis=-1) + 1
    true_predictions = []
    true_labels = []
    # Fetch last timestep outputs only: scan each sequence from the end and
    # keep the first position whose label is not the -100 ignore marker.
    bsize, seqlen = predictions.shape[0], predictions.shape[1]
    for b in range(bsize):
        for pos in range(seqlen - 1, -1, -1):
            if labels[b, pos, 0] != -100:
                true_predictions.extend(predictions[b, pos])
                true_labels.extend(labels[b, pos])
                break
    results = metric.compute(predictions=true_predictions, references=true_labels)
    # `with` fixes the original unclosed file handle passed to pickle.dump.
    out_path = os.path.join(res_dir, os.path.basename(ckpt_dir) + '.pkl')
    with open(out_path, "wb") as fh:
        pickle.dump(
            {
                "labels": true_labels,
                "preds": true_predictions,
            },
            fh,
        )
    return results
if __name__ == '__main__':
    # Load the evaluation split with dense learning signals and images
    # packed into a single stream.
    dataset = roundataset(
        'data/test_clean_sections.pickle',
        'data/image_feats.pickle',
        separate_images=False,
        dense_learning_signals=True,
    )
    print("[info] test set loaded, len =", len(dataset))

    # Restore the listener model from the checkpoint given on the CLI.
    listener = DebertaForPhotobookListener.from_pretrained(ckpt_dir)

    # Evaluation-only run: training hyperparameters are supplied solely to
    # satisfy the TrainingArguments constructor.
    eval_args = TrainingArguments(
        output_dir=ckpt_dir,
        do_train=False,
        do_eval=True,
        per_device_eval_batch_size=BATCH_SIZE,
        per_device_train_batch_size=BATCH_SIZE,
        learning_rate=PEAK_LR,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        num_train_epochs=EPOCHS,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        metric_for_best_model="eval_accuracy",
        save_total_limit=3,
        dataloader_num_workers=8,
        logging_steps=50,
        load_best_model_at_end=True,
    )
    trainer = Trainer(
        listener,
        eval_args,
        eval_dataset=dataset,
        compute_metrics=compute_metrics,
    )

    # Run evaluation and persist the metrics under the "test" prefix.
    scores = trainer.evaluate()
    scores["eval_samples"] = len(dataset)
    trainer.log_metrics("test", scores)
    trainer.save_metrics("test", scores)