-
Notifications
You must be signed in to change notification settings - Fork 1
/
keras_to_tf.py
64 lines (54 loc) · 2.84 KB
/
keras_to_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pdb
import os
from keras.models import model_from_json
from keras import backend as K
from keras.models import load_model
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile
from kapre.time_frequency import Spectrogram, Melspectrogram
from kapre.utils import Normalization2D
from kapre.augmentation import AdditiveNoise
# Convert the Keras birdwatcher model into TensorFlow artefacts:
#   1. a SessionBundle export (classification signature) under MODEL_DIR,
#   2. a text GraphDef plus a checkpoint, and
#   3. a single frozen GraphDef (MODEL_DIR/frozen.pb) built from (2).

# --- Session setup ----------------------------------------------------------
# Keras is bound to this session, so the models and weights loaded below live
# in sess.graph, which is the graph exported and frozen later on.
sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(0)  # all new operations will be in test mode from now on

# Directory holding the source .h5 model and receiving every output file.
MODEL_DIR = "models"

# kapre layers Keras must know about to deserialize the model. Every kapre
# layer imported by this script is registered; entries the model does not
# use are simply never looked up.
custom_objects = {
    "Spectrogram": Spectrogram,
    "Melspectrogram": Melspectrogram,
    "Normalization2D": Normalization2D,
    "AdditiveNoise": AdditiveNoise,
}

# --- Load and rebuild the model ---------------------------------------------
# Serialize the model and get its weights, for quick re-building.
weights_path = os.path.join(MODEL_DIR, "birdwatcher_weights.h5")
previous_model = load_model(os.path.join(MODEL_DIR, "birdwatcher.h5"),
                            custom_objects=custom_objects)
previous_model.save_weights(weights_path)
model_json = previous_model.to_json()
new_model = model_from_json(model_json, custom_objects=custom_objects)
new_model.load_weights(weights_path)

# --- SessionBundle export ---------------------------------------------------
checkpoint_prefix = os.path.join(MODEL_DIR, "saved_checkpoint")
export_path = MODEL_DIR
export_version = 1
checkpoint_state_name = "checkpoint_state"
saver = tf.train.Saver(sharded=True)
model_exporter = exporter.Exporter(saver)
signature = exporter.classification_signature(input_tensor=new_model.input,
                                              scores_tensor=new_model.output)
model_exporter.init(sess.graph.as_graph_def(),
                    default_graph_signature=signature)
model_exporter.export(export_path, tf.constant(export_version), sess)

# --- Graph definition + checkpoint for freezing -----------------------------
input_graph_name = "birdwatcher.pbtxt"
output_graph_name = "frozen.pb"
tf.train.write_graph(sess.graph.as_graph_def(), MODEL_DIR, input_graph_name)
# Use the checkpoint path reported by the saver instead of re-deriving
# "<prefix>-0" by hand, so the freeze step always reads what was written.
checkpoint_path = saver.save(sess, checkpoint_prefix, global_step=0,
                             latest_filename=checkpoint_state_name)

# --- Freeze variables into constants ----------------------------------------
input_graph_path = os.path.join(MODEL_DIR, input_graph_name)
input_saver_def_path = ""
input_binary = False  # birdwatcher.pbtxt was written in text format above
input_checkpoint_path = checkpoint_path
# Note: normally this should be only "output_node".
# NOTE(review): previous_model and new_model are both built in sess.graph, so
# the rebuilt copy's ops may carry uniquified scopes (e.g. "loss_1/...");
# confirm which model's softmax "loss/Softmax" actually refers to.
output_node_names = "loss/Softmax"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(MODEL_DIR, output_graph_name)
clear_devices = False
# Pass everything by keyword so no argument can silently shift position.
# initializer_nodes takes a comma-separated string of op names; "" = none.
freeze_graph.freeze_graph(input_graph=input_graph_path,
                          input_saver=input_saver_def_path,
                          input_binary=input_binary,
                          input_checkpoint=input_checkpoint_path,
                          output_node_names=output_node_names,
                          restore_op_name=restore_op_name,
                          filename_tensor_name=filename_tensor_name,
                          output_graph=output_graph_path,
                          clear_devices=clear_devices,
                          initializer_nodes="")

# Follow-up step, run by hand (adjust the tool path to your environment):
# python env/lib/python3.6/site-packages/tensorflow/python/tools/optimize_for_inference.py --input=models/frozen.pb --output=models/inference.pb --input_names=input_1 --output_names=loss/Softmax