Skip to content

anthony0727/Image-Classification

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

80 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Image-Classification

Classification of popular image datasets (CIFAR-10, ImageNet, etc.).

Usage:

# Load Module
import os

import tensorflow as tf
import gc

from model.vggnet import VGGNet
from util import train

# Declare Static Variables
# Each *_log_dir is used both as the checkpoint prefix handed to the Saver
# and as the parent of the 'log' directory handed to train().
vgg11_log_dir, vgg13_log_dir = './vgg/vgg11', './vgg/vgg13'

# Arguments forwarded to VGGNet.build(); presumably (height, width, channels)
# and the number of target classes -- confirm against model.vggnet.
n_class = 100
input_shape = (32, 32, 3)

# Build VGG11

pretrained_vgg = VGGNet(11)
pretrained_vgg.build(input_shape, n_class)

# Pretrain VGG11
# Everything below runs with the model's own graph as the default graph, so
# the optimizer ops and the Saver are attached to it.
with pretrained_vgg.graph.as_default() as graph:
    # The model exposes its loss under a fixed tensor name.
    loss = graph.get_tensor_by_name('loss:0')
    # Learning rate defaults to 1e-2 but can be overridden through feed_dict.
    lr = tf.placeholder_with_default(1e-2, (), name='learning_rate')
    global_step = tf.train.get_or_create_global_step()

    # SGD with momentum 0.9; minimize() also increments global_step.
    with tf.variable_scope('optimizer'):
        tf.train.MomentumOptimizer(
            learning_rate=lr, momentum=0.9
        ).minimize(loss, global_step=global_step)

    # train() comes from util; it receives the session plus a log directory
    # and hands a session back (its internals are not visible here).
    sess = train(tf.Session(graph=graph), os.path.join(vgg11_log_dir, 'log'))

    # Checkpoint every saveable variable of the trained VGG11.
    saver = tf.train.Saver()
    saver.save(sess, vgg11_log_dir)

# Release the memory held by VGG11

# The session created for pretraining is still open and keeps its own
# reference to the graph, so deleting the model's `graph` attribute alone
# cannot free anything. Close and drop the session first, then drop the
# graph and force a collection pass.
sess.close()
del sess
del pretrained_vgg.graph
gc.collect()

# Build VGG13
reconstructed_vgg = VGGNet(13)
reconstructed_vgg.build(input_shape, n_class)

# Collect weights to transfer

# tf.get_collection() filters variable names with re.match(), i.e. each
# pattern is anchored at the START of the name. The former pattern
# '(VGGBLOCK-3/conv1|conv2)' therefore meant "starts with VGGBLOCK-3/conv1"
# OR "starts with conv2", so VGGBLOCK-3/conv2 was never selected. The
# alternation must sit inside the block scope.
transfer_scopes = (
    'VGGBLOCK-1/conv1',
    'VGGBLOCK-2/conv1',
    'VGGBLOCK-3/(conv1|conv2)',
    'FC/',
)

with reconstructed_vgg.graph.as_default():
    # Order follows transfer_scopes: conv layers first, then the FC layers.
    transfer_weights = []
    for scope in transfer_scopes:
        transfer_weights += tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope)

# Restore pretrained weights

with reconstructed_vgg.graph.as_default() as graph:
    # A Saver limited to the selected variables reads only those entries
    # from the VGG11 checkpoint; every other VGG13 variable is left untouched.
    saver = tf.train.Saver(transfer_weights)

    sess = tf.Session(graph=graph)
    saver.restore(sess, vgg11_log_dir)

# Train VGG13

with reconstructed_vgg.graph.as_default() as graph:
    loss = graph.get_tensor_by_name('loss:0')
    # Learning rate defaults to 1e-2 but can be overridden through feed_dict.
    lr = tf.placeholder_with_default(1e-2, (), name='learning_rate')
    global_step = tf.train.get_or_create_global_step()

    with tf.variable_scope('optimizer'):
        tf.train.AdamOptimizer(lr).minimize(loss, global_step)

    # Initialize only the variables that the checkpoint restore did not
    # cover (non-transferred layers, optimizer variables, global_step), so
    # the transferred weights are not overwritten.
    with tf.variable_scope('initialization'):
        global_vars = tf.global_variables()
        is_initialized = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])

        uninitialized_vars = [
            var
            for var, initialized in zip(global_vars, is_initialized)
            if not initialized
        ]

    if uninitialized_vars:
        sess.run(tf.variables_initializer(uninitialized_vars))

    # Keep training in the SAME session: creating a new tf.Session here would
    # discard both the restored VGG11 weights and the initialization above.
    # NOTE(review): assumes train() does not re-initialize variables itself --
    # confirm in util.train.
    sess = train(sess, os.path.join(vgg13_log_dir, 'log'))
    saver = tf.train.Saver()
    saver.save(sess, vgg13_log_dir)

About

Classify popular datasets (CIFAR-10, ImageNet, etc.).

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published