Visualize using ckpt files #2
Comments
You can do this in three steps.

(1) Convert the checkpoint into a frozen pb file:

    import os
    import sys

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API


    def generate_pb_file(model_dir, output_node_name, option='latest'):
        '''
        parameters:
            model_dir: the output model directory
            output_node_name: list of output node names in the graph
            option: 'latest': generate pb file from the latest checkpoint
                    'min': generate pb file from the minimum-validation-error checkpoint
        '''
        tf.reset_default_graph()
        config = tf.ConfigProto(allow_soft_placement=True)
        # expect exactly one meta file in each saved directory
        allfiles = os.listdir(model_dir)
        meta_files = [s for s in allfiles if s.endswith('.meta')]
        assert len(meta_files) == 1, 'expected exactly one meta file'
        meta_path = os.path.join(model_dir, meta_files[0])
        with tf.Session(config=config) as sess:
            # Restore the graph
            # clear_devices: do not care which GPU to use
            saver = tf.train.import_meta_graph(meta_path, clear_devices=True)
            # count total number of parameters in the model
            total_param_count = np.sum(
                [np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
            print('Total number of trainable parameters:', total_param_count)
            # Load weights
            if option == 'latest':  # restore from latest checkpoint
                saver.restore(sess, tf.train.latest_checkpoint(model_dir))
                output_name = 'output_latest.pb'
            elif option == 'min':  # restore from min validation error
                saver.restore(sess, os.path.join(model_dir, 'min-validation_error'))
                output_name = 'output_min.pb'
            else:
                sys.exit('Do not have the specified checkpoint file')
            # replace variables with constants holding the restored weights
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(
                sess,
                tf.get_default_graph().as_graph_def(),
                output_node_name)
        # generate corresponding file in the model checkpoint directory
        graph_pb_path = os.path.join(model_dir, output_name)
        # Save the frozen graph
        with open(graph_pb_path, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
        print('Saved model pb file to path:', graph_pb_path)

(2) Load your graph:

    with tf.gfile.GFile(graph_pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

(3) Feed your target layer with the pre-processed image:

    # `inputs` (the input placeholder) and `channel` (the target layer output)
    # are tensors of the imported graph, e.g. looked up with graph.get_tensor_by_name
    with tf.Session(graph=graph) as sess:
        channel_value = sess.run(channel, feed_dict={inputs: image_value})
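The file-handling part of step (1) can be tried without TensorFlow. Below is a minimal sketch, not the author's code: the helper names `find_meta_file` and `output_pb_name` are hypothetical and only mirror the checks in the snippet above (exactly one `.meta` file per directory, and the `'latest'` / `'min'` options mapped to `output_latest.pb` / `output_min.pb`). It raises `ValueError` instead of asserting or calling `sys.exit`, which is an assumption about what is more convenient for a caller.

```python
import os
import tempfile


def find_meta_file(model_dir):
    """Return the path of the single .meta file in model_dir.

    Raises ValueError when the directory holds zero or several .meta files.
    """
    meta_files = sorted(f for f in os.listdir(model_dir) if f.endswith('.meta'))
    if len(meta_files) != 1:
        raise ValueError(
            'expected exactly one .meta file in {!r}, found {}'.format(
                model_dir, len(meta_files)))
    return os.path.join(model_dir, meta_files[0])


def output_pb_name(option):
    """Map the checkpoint option to the pb file name used in the snippet above."""
    names = {'latest': 'output_latest.pb', 'min': 'output_min.pb'}
    if option not in names:
        raise ValueError('unknown checkpoint option: {!r}'.format(option))
    return names[option]


# Small demonstration in a throwaway directory (file names are made up).
with tempfile.TemporaryDirectory() as demo_dir:
    open(os.path.join(demo_dir, 'model.ckpt.meta'), 'w').close()
    open(os.path.join(demo_dir, 'checkpoint'), 'w').close()
    print(os.path.basename(find_meta_file(demo_dir)))
    print(output_pb_name('latest'))
```

Running the lookup before opening a session reports a missing or ambiguous `.meta` file early, before any graph is restored.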
Hi @conan7882
I wanted to know how we can visualize a custom-trained VGG19 model saved in ckpt format.
Thanks