-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
116 lines (107 loc) · 4.5 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/7/15 15:47
# @Author : TheTao
# @Site :
# @File : test.py
# @Software: PyCharm
import time
import warnings
import json as js
import tensorflow as tf
from model import Model
from params_utils import get_params
from build_inputs import input_from_line_with_feature
from data_utils import BatchManager, get_logger, get_dict, test_ner, result_to_json
warnings.filterwarnings("ignore")
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Batch evaluation: run NER over one split, log the conll-eval report,
# and persist a new best F1 into the model's TF variable when improved.
def evaluate(sess, param, model, name, batchmanager, logger):
    """Evaluate the model on one dataset split and track the best F1.

    Args:
        sess: active TensorFlow session.
        param: config object; ``param.result_path`` is where results are written.
        model: Model instance exposing ``evaluate`` plus the persistent
            ``best_dev_f1`` / ``best_test_f1`` variables.
        name: split being evaluated, "dev" or "test" (others are a no-op).
        batchmanager: BatchManager yielding the evaluation batches.
        logger: logger that receives the per-line evaluation report.

    Returns:
        bool: True if this run set a new best F1 for the split.
    """
    # Decode the whole split with the current model weights.
    ner_results = model.evaluate(sess, batchmanager)
    # Write predictions out and get back the conlleval report lines.
    eval_lines = test_ner(ner_results, param.result_path)
    for line in eval_lines:
        logger.info(line)
    # The second report line ends with the overall FB1 score.
    f1 = float(eval_lines[1].strip().split()[-1])
    # Pick the persistent best-F1 variable matching this split.
    if name == "dev":
        best_var = model.best_dev_f1
    elif name == "test":
        best_var = model.best_test_f1
    else:
        return False
    if f1 > best_var.eval():
        # Store the improved score in the graph so it survives checkpointing.
        tf.assign(best_var, f1).eval()
        logger.info("new best {} f1 score:{:>.3f}".format(name, f1))
        return True
    return False
def test(param):
    """Restore the trained model from checkpoint and evaluate on the test set.

    Args:
        param: config object carrying clip/dropout/lr sanity bounds, batch
            size, file paths (``test_log_file``, ``dict_file``, ``ckpt_path``)
            and the result path used by ``evaluate``.

    Side effects:
        Logs progress and the evaluation report; writes prediction results
        via ``evaluate``. Returns early (after logging) when no checkpoint
        exists, instead of evaluating an uninitialized model.
    """
    # Sanity-check hyper-parameters before building anything.
    assert param.clip < 5.1, "gradient clip should't be too much"
    assert 0 <= param.dropout < 1, "dropout rate between 0 and 1"
    assert param.lr > 0, "learning rate must larger than zero"
    # Build the batch iterator for the test split.
    test_manager = BatchManager(param.test_batch_size, name='test')
    number_dataset = test_manager.len_data
    print("total of number test data is {}".format(number_dataset))
    # Logging setup.
    logger = get_logger(param.test_log_file)
    # Vocabulary / tag dictionaries.
    mapping_dict = get_dict(param.dict_file)
    # Build the model graph.
    model = Model(param, mapping_dict)
    # GPU session configuration.
    gpu_config = tf.ConfigProto()
    with tf.Session(config=gpu_config) as sess:
        logger.info("start testing...")
        start = time.time()
        # Look for a trained checkpoint to restore.
        ckpt_path = param.ckpt_path
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            logger.info("Reading model parameters from {}".format(ckpt.model_checkpoint_path))
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # BUGFIX: previously fell through and evaluated an uninitialized
            # model; abort instead when no checkpoint is available.
            logger.info("Cannot find the ckpt files!")
            return
        # Run the evaluation pass and report timing.
        evaluate(sess, param, model, "test", test_manager, logger)
        logger.info("The best_f1 on test_dataset is {:.2f}".format(model.best_test_f1.eval()))
        logger.info('Time test for {:.2f} batch is {:.2f} sec\n'.format(param.test_batch_size, time.time() - start))
# Interactive single-sentence prediction loop.
def predict_line(param):
    """Restore the trained model and predict entities for typed-in sentences.

    Args:
        param: config object carrying ``test_log_file``, ``dict_file`` and
            ``ckpt_path``.

    Side effects:
        Reads sentences from stdin in an endless loop; for each one, writes
        the JSON result to ``./result/result.json`` and prints it. Returns
        early (after logging) when no checkpoint exists.
    """
    # Logging setup.
    logger = get_logger(param.test_log_file)
    tf_config = tf.ConfigProto()
    # Vocabulary / tag dictionaries.
    mapping_dict = get_dict(param.dict_file)
    # Build the model graph so the saver can restore into it.
    model = Model(param, mapping_dict)
    with tf.Session(config=tf_config) as sess:
        # Look for a trained checkpoint to restore.
        ckpt_path = param.ckpt_path
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            logger.info("Reading model parameters from {}".format(ckpt.model_checkpoint_path))
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # BUGFIX: previously fell through into the prediction loop with an
            # uninitialized model; abort instead when no checkpoint is found.
            logger.info("Cannot find the ckpt files!")
            return
        while True:
            # Read one sentence and run it through the tagger.
            line = input("请输入测试句子:")
            raw_inputs, model_inputs = input_from_line_with_feature(line)
            tag = model.evaluate_line(sess, model_inputs)
            result = result_to_json(raw_inputs, tag)
            result = js.dumps(result, ensure_ascii=False, indent=4, separators=(',', ': '))
            # Persist the latest prediction (overwritten on every iteration).
            with open('./result/result.json', 'w', encoding='utf-8') as f:
                f.write(result)
            print("预测结果为:{}".format(result))
# Script entry point: load the configuration and start interactive prediction.
if __name__ == '__main__':
    predict_line(get_params())