
Visualize using ckpt files #2

Open
harsh-agar opened this issue Oct 4, 2018 · 1 comment


harsh-agar commented Oct 4, 2018

Hi @conan7882
I wanted to know how we can visualize a custom-trained VGG19 model saved in ckpt format.
Thanks

@Sirius083

You can:
(1) Generate a frozen .pb file from the checkpoint files

import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (use tf.compat.v1 under TF 2.x)


def generate_pb_file(model_dir, output_node_names, option='latest'):
    '''
    Freeze a checkpoint into a .pb file.

    parameters:
    model_dir: directory holding the checkpoint files and a single .meta file
    output_node_names: name (str) or list of names of the output nodes in the graph
    option: 'latest': generate the pb file from the latest checkpoint
            'min'   : generate the pb file from the minimum-validation-error checkpoint

    returns: path of the written .pb file
    '''
    tf.reset_default_graph()
    config = tf.ConfigProto(allow_soft_placement=True)

    # convert_variables_to_constants expects a list of node names
    if isinstance(output_node_names, str):
        output_node_names = [output_node_names]

    # expect exactly one meta file in each saved directory
    meta_files = [s for s in os.listdir(model_dir) if s.endswith('.meta')]
    assert len(meta_files) == 1, 'expected one .meta file, found %d' % len(meta_files)
    meta_path = os.path.join(model_dir, meta_files[0])

    with tf.Session(config=config) as sess:
        # Restore the graph structure
        # clear_devices: drop device placements stored in the graph (e.g. which GPU)
        saver = tf.train.import_meta_graph(meta_path, clear_devices=True)

        # count the total number of trainable parameters in the model
        total_param_count = np.sum([np.prod(v.get_shape().as_list())
                                    for v in tf.trainable_variables()])
        print('Total trainable parameters:', total_param_count)

        # Load weights
        if option == 'latest':  # restore from the latest checkpoint
            saver.restore(sess, tf.train.latest_checkpoint(model_dir))
            output_name = 'output_latest.pb'
        elif option == 'min':  # checkpoint prefix written by the author's own training script
            saver.restore(sess, os.path.join(model_dir, 'min-validation_error'))
            output_name = 'output_min.pb'
        else:
            raise ValueError('Unknown checkpoint option: %r' % option)

        # replace every variable with a constant holding its restored value
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names)

    # write the frozen graph into the model checkpoint directory
    graph_pb_path = os.path.join(model_dir, output_name)
    with open(graph_pb_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

    print('Saved model pb file to path:', graph_pb_path)
    return graph_pb_path
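In TensorFlow 1.x, `convert_variables_to_constants` takes a list of node (op) names such as `'prob'`, while `sess.run` and `get_tensor_by_name` use tensor names with an output index such as `'prob:0'`. Mixing the two up (a bare string instead of a list, or a name with `:0`) is a common reason freezing fails. A minimal sketch of a normalizing helper; `to_node_names` is hypothetical and not part of TensorFlow:

```python
def to_node_names(names):
    """Normalize output names for convert_variables_to_constants.

    Accepts a single string or a list of strings and returns a list of
    node (op) names, stripping a tensor output suffix such as ':0' and a
    leading '^' (control-dependency marker).
    """
    if isinstance(names, str):
        names = [names]
    result = []
    for name in names:
        name = name.lstrip('^')
        # 'conv5_4/Relu:0' -> 'conv5_4/Relu'; names without ':<digits>' are kept
        op_name, _, index = name.rpartition(':')
        if op_name and index.isdigit():
            name = op_name
        result.append(name)
    return result
```

For example, `to_node_names('conv5_4/Relu:0')` gives `['conv5_4/Relu']`, which can then be passed as `output_node_names`.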

(2) Load the frozen graph from the .pb file

    with tf.gfile.GFile(graph_pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

(3) Feed the pre-processed image to the input tensor and fetch the activations of your target layer

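inputs = graph.get_tensor_by_name('input:0')        # hypothetical tensor names:
layer = graph.get_tensor_by_name('conv5_4/Relu:0')  # replace with your graph's own
with tf.Session(graph=graph) as sess:
    channel_value = sess.run(layer, feed_dict={inputs: image_value})[0, :, :, 0]  # image 0, channel 0 (NHWC)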
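Step 3 leaves open how the image is pre-processed and how the fetched activations become a picture. Below is a minimal NumPy sketch of both. It assumes Caffe-style VGG preprocessing (BGR channel order, per-channel ImageNet mean subtraction) and NHWC layout. `preprocess_vgg` and `channel_to_image` are hypothetical helpers. A custom-trained model should reuse whatever preprocessing it was trained with.

```python
import numpy as np

# Assumption: Caffe-style VGG preprocessing (BGR order, ImageNet mean per channel).
VGG_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def preprocess_vgg(rgb_image):
    """Turn an HxWx3 RGB uint8 image into a 1xHxWx3 float32 batch."""
    bgr = rgb_image[..., ::-1].astype(np.float32)  # RGB -> BGR (astype makes a copy)
    bgr -= VGG_MEAN_BGR                            # subtract the per-channel mean
    return bgr[np.newaxis, ...]                    # add the batch dimension


def channel_to_image(channel_value):
    """Min-max scale a 2-D activation map to uint8 in [0, 255] for display."""
    x = np.asarray(channel_value, dtype=np.float32)
    x = x - x.min()
    peak = x.max()
    if peak > 0:  # a constant map stays all zeros instead of dividing by zero
        x = x / peak
    return (x * 255.0).round().astype(np.uint8)
```

Under these assumptions, `image_value = preprocess_vgg(rgb)` is what gets fed to `inputs`, and `channel_to_image(channel_value)` yields an array that can be displayed with, for example, matplotlib's `plt.imshow(img, cmap='gray')`.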
