-
Notifications
You must be signed in to change notification settings - Fork 1
/
freeze_model.py
76 lines (59 loc) · 2.79 KB
/
freeze_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Model export
============
Exports/Freezes training checkpoints for future detections.
"""
import tensorflow as tf
from utils import *
from config import *
from argparse import ArgumentParser
from object_detection import exporter
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2
__description__ = "Exports/Freezes training checkpoints for future detections."


def _str2bool(value):
    """
    Parse a command-line string into a boolean.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True`` because
    any non-empty string is truthy, so the flag could never be turned off
    explicitly. This helper accepts the usual spellings of "true".

    :param value: raw command-line value.
    :return: True for "1"/"true"/"t"/"yes"/"y" (case-insensitive), else False.
    """
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


# Parse args at import time so `args` is available to main().
parser = ArgumentParser(description=__description__)
parser.add_argument("--input-type", type=str, choices=("image_tensor", "tf_example", "encoded_image_string_tensor"),
                    default="image_tensor",
                    help="Type of input node, default is {default}.".format(default="image_tensor"))
parser.add_argument("--config-path", type=str, default="training/config/smartbin_pipeline.config",
                    help="Path of the model config file, default is {default}.".format(
                        default="training/config/smartbin_pipeline.config"))
parser.add_argument("--checkpoint-prefix", type=str, default="model.ckpt",
                    help="Path to trained checkpoint, typically of the form path/to/model.ckpt, default is {default}.".format(
                        default="model.ckpt"))
parser.add_argument("--output-directory", type=str, default="training/outputs",
                    help="Path to write outputs, default is {default}.".format(
                        default="training/outputs"))
parser.add_argument("--config-override", type=str, default="",
                    help="Override pipeline config file content, default is {default}.".format(
                        default="''"))
# `type=bool` would treat "False" as True (non-empty string); use _str2bool so
# both `--write-inference-graph True` and `--write-inference-graph False` work.
parser.add_argument("--write-inference-graph", type=_str2bool, default=False,
                    help="Write inference graph to disk, default is {default}.".format(
                        default=False))
args = parser.parse_args()
def main(_):
    """
    Freeze the most recent training checkpoint into an inference graph.

    Finds the latest checkpoint under CHECKPOINTS_DIR, loads the pipeline
    config (optionally overridden from the command line), and exports the
    frozen inference graph to the output directory.

    :param _: unused parameter (tf.app.run passes argv to its callback).
    :return: void.
    """
    # Retrieve latest checkpoint prefix.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    # BUGFIX: the flag is --checkpoint-prefix, so argparse exposes it as
    # `checkpoint_prefix`; the original read `checkpoints_prefix`, which
    # raises AttributeError at runtime.
    latest_checkpoint = find_latest_checkpoint(dir=CHECKPOINTS_DIR, prefix=args.checkpoint_prefix)
    if latest_checkpoint is None:
        # Nothing to export yet.
        return
    # Read model config file.
    # BUGFIX: the flag is --config-path, so the attribute is `config_path`;
    # the original read `args.config_file`, which does not exist.
    with tf.gfile.GFile(args.config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    # Override config file if any updates are provided.
    text_format.Merge(args.config_override, pipeline_config)
    # Freeze checkpoint file.
    exporter.export_inference_graph(
        args.input_type, pipeline_config, latest_checkpoint,
        args.output_directory, write_inference_graph=args.write_inference_graph)


if __name__ == "__main__":
    tf.app.run()