test.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author: lapis-hong
# @Date : 2018/1/15
"""Wide and Deep Model Evaluation"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
import sys
import time
import tensorflow as tf
from wide_resdnn.build_estimator import build_estimator
from wide_resdnn.dataset import input_fn
from wide_resdnn.read_conf import Config
from wide_resdnn.util import elapse_time

# Configuration path; change it to use a different dataset.
CONFIG = Config("conf/criteo")
# CONFIG = Config("conf/avazu")
parser = argparse.ArgumentParser(description='Evaluate Wide and Deep Model.')
parser.add_argument(
    '--conf_dir', type=str, default="conf/criteo",
    help='Path to the configuration directory. '
         '(Note: the script currently reads the hardcoded CONFIG above; this flag is not applied.)')
parser.add_argument(
    '--test_data', type=str, default=CONFIG.train["test_data"],
    help='Evaluation data directory.')
parser.add_argument(
    '--model_dir', type=str, default=CONFIG.train["model_dir"],
    help='Model checkpoint directory for evaluation.')
parser.add_argument(
    '--model_type', type=str, default=CONFIG.train["model_type"],
    help="Valid model types: {'wide', 'deep', 'wide_deep'}.")
parser.add_argument(
    '--batch_size', type=int, default=CONFIG.train["batch_size"],
    help='Number of examples per batch.')
parser.add_argument(
    '--checkpoint_path', type=str, default=CONFIG.train["checkpoint_path"],
    help='Path of a specific checkpoint to evaluate. If None, the latest checkpoint in model_dir is used.')


def main(_):
    print("Using TensorFlow version %s; TensorFlow 1.10 or later is required." % tf.__version__)
    # assert "1.4" <= tf.__version__, "TensorFlow r1.4 or later is needed"
    # (Note: a plain string comparison is lexicographic, so e.g. "1.10" < "1.4".)
    print('Model type: {}'.format(FLAGS.model_type))
    model_dir = os.path.join(FLAGS.model_dir, FLAGS.model_type)
    print('Model directory: {}'.format(model_dir))
    model = build_estimator(model_dir, FLAGS.model_type)
    tf.logging.info('Build estimator: {}'.format(model))
    tf.logging.info('=' * 30 + ' START TESTING ' + '=' * 30)
    s_time = time.time()
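    # The input_fn arguments are assumed to be (data_file, num_epochs, batch_size, shuffle):
    # evaluate for a single epoch without shuffling.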
    results = model.evaluate(
        input_fn=lambda: input_fn(FLAGS.test_data, 1, FLAGS.batch_size, False),
        steps=None,  # If None, evaluate until input_fn raises an end-of-input exception.
        hooks=None,
        checkpoint_path=FLAGS.checkpoint_path,  # If None, the latest checkpoint in model_dir is used.
        name=None)
    tf.logging.info('=' * 30 + ' FINISH TESTING, TOOK {} '.format(elapse_time(s_time)) + '=' * 30)
    # Display evaluation metrics.
    print('-' * 80)
    for key in sorted(results):
        print('%s: %s' % (key, results[key]))


if __name__ == '__main__':
    # Set verbosity to INFO to track progress; the default WARN is quieter, and ERROR is quietest.
    tf.logging.set_verbosity(tf.logging.INFO)
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
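
For reference, a typical invocation from the repository root might look like the following (the flags mirror the argparse definitions above; the checkpoint path shown is hypothetical):

    python test.py --model_type wide_deep --batch_size 256
    python test.py --model_type deep --checkpoint_path ./model/deep/model.ckpt-20000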