Removing deprecated tensorflow api #76

Open · wants to merge 1 commit into master
Changes from all commits
datagenerator.py (8 changes: 3 additions & 5 deletions)
@@ -63,16 +63,14 @@ def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True,
         self.labels = convert_to_tensor(self.labels, dtype=dtypes.int32)

         # create dataset
-        data = Dataset.from_tensor_slices((self.img_paths, self.labels))
+        data = tf.data.Dataset.from_tensor_slices((self.img_paths, self.labels))

         # distinguish between train/infer. when calling the parsing functions
         if mode == 'training':
-            data = data.map(self._parse_function_train, num_threads=8,
-                            output_buffer_size=100*batch_size)
+            data = data.map(self._parse_function_train, num_parallel_calls=8).prefetch(100 * batch_size)

         elif mode == 'inference':
-            data = data.map(self._parse_function_inference, num_threads=8,
-                            output_buffer_size=100*batch_size)
+            data = data.map(self._parse_function_inference, num_parallel_calls=8).prefetch(100 * batch_size)

         else:
             raise ValueError("Invalid mode '%s'." % (mode))
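Note on the datagenerator.py change: Dataset.map() dropped its num_threads and output_buffer_size arguments when the dataset API graduated from tf.contrib.data to tf.data in TensorFlow 1.4; num_parallel_calls plus an explicit prefetch() is the replacement used in this commit. A minimal sketch of the new-style pipeline, assuming TensorFlow 1.4+ (the file names, batch size, and the _parse helper below are placeholders, not code from this repository):

import tensorflow as tf

batch_size = 128  # placeholder value for illustration

# Placeholder inputs; in datagenerator.py these come from the text file parsed in __init__.
img_paths = tf.constant(['example_0.jpg', 'example_1.jpg'])
labels = tf.constant([0, 1])

# tf.data.Dataset replaces the older tf.contrib.data.Dataset entry point.
data = tf.data.Dataset.from_tensor_slices((img_paths, labels))

def _parse(filename, label):
    # Hypothetical stand-in for _parse_function_train: read and decode one image.
    img = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
    img = tf.image.resize_images(img, [227, 227])
    return img, label

# num_threads / output_buffer_size are gone from Dataset.map(); num_parallel_calls
# plus an explicit prefetch() provides the parallel parsing and read-ahead buffer.
data = data.map(_parse, num_parallel_calls=8).prefetch(100 * batch_size)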
finetune.py (23 changes: 10 additions & 13 deletions)
@@ -27,8 +27,8 @@ class on any given dataset. Specify the configuration settings at the
 """

 # Path to the textfiles for the trainings and validation set
-train_file = '/path/to/train.txt'
-val_file = '/path/to/val.txt'
+train_file = 'train.txt'
+val_file = 'val.txt'

 # Learning params
 learning_rate = 0.01
@@ -44,8 +44,8 @@ class on any given dataset. Specify the configuration settings at the
 display_step = 20

 # Path for tf.summary.FileWriter and to store model checkpoints
-filewriter_path = "/tmp/finetune_alexnet/tensorboard"
-checkpoint_path = "/tmp/finetune_alexnet/checkpoints"
+filewriter_path = "tmp/tensorboard"
+checkpoint_path = "tmp/checkpoints"

 """
 Main Part of the finetuning Script.
@@ -93,8 +93,8 @@ class on any given dataset. Specify the configuration settings at the

 # Op for calculating the loss
 with tf.name_scope("cross_ent"):
-    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=score,
-                                                                  labels=y))
+    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=score,
+                                                                     labels=y))

 # Train op
 with tf.name_scope("train"):
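Note on the loss change: tf.nn.softmax_cross_entropy_with_logits was deprecated because its successor also lets gradients flow into the labels argument. With fed label placeholders, as here, that extra gradient path has no practical effect, but wrapping the labels in tf.stop_gradient makes the intent explicit. A small sketch, where score and y are stand-ins for the logits and one-hot label placeholders in finetune.py:

import tensorflow as tf

# Stand-ins for the network logits and one-hot label placeholders.
score = tf.placeholder(tf.float32, [None, 2])
y = tf.placeholder(tf.float32, [None, 2])

# The _v2 op backpropagates into labels by default; tf.stop_gradient restores
# the old behaviour for plain one-hot targets.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(logits=score,
                                               labels=tf.stop_gradient(y)))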
@@ -117,7 +117,6 @@ class on any given dataset. Specify the configuration settings at the
 # Add the loss to summary
 tf.summary.scalar('cross_entropy', loss)

-
 # Evaluation op: Accuracy of the model
 with tf.name_scope("accuracy"):
     correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
@@ -136,12 +135,11 @@ class on any given dataset. Specify the configuration settings at the
 saver = tf.train.Saver()

 # Get the number of training/validation steps per epoch
-train_batches_per_epoch = int(np.floor(tr_data.data_size/batch_size))
+train_batches_per_epoch = int(np.floor(tr_data.data_size / batch_size))
 val_batches_per_epoch = int(np.floor(val_data.data_size / batch_size))

 # Start Tensorflow session
 with tf.Session() as sess:
-
     # Initialize all variables
     sess.run(tf.global_variables_initializer())

@@ -158,7 +156,7 @@ class on any given dataset. Specify the configuration settings at the
     # Loop over number of epochs
     for epoch in range(num_epochs):

-        print("{} Epoch number: {}".format(datetime.now(), epoch+1))
+        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))

         # Initialize iterator with the training dataset
         sess.run(training_init_op)
@@ -179,15 +177,14 @@ class on any given dataset. Specify the configuration settings at the
                                                         y: label_batch,
                                                         keep_prob: 1.})

-                writer.add_summary(s, epoch*train_batches_per_epoch + step)
+                writer.add_summary(s, epoch * train_batches_per_epoch + step)

         # Validate the model on the entire validation set
         print("{} Start validation".format(datetime.now()))
         sess.run(validation_init_op)
         test_acc = 0.
         test_count = 0
         for _ in range(val_batches_per_epoch):
-
             img_batch, label_batch = sess.run(next_batch)
             acc = sess.run(accuracy, feed_dict={x: img_batch,
                                                 y: label_batch,
@@ -201,7 +198,7 @@ class on any given dataset. Specify the configuration settings at the

         # save checkpoint of the model
         checkpoint_name = os.path.join(checkpoint_path,
-                                       'model_epoch'+str(epoch+1)+'.ckpt')
+                                       'model_epoch' + str(epoch + 1) + '.ckpt')
         save_path = saver.save(sess, checkpoint_name)

         print("{} Model checkpoint saved at {}".format(datetime.now(),
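For readers following the training loop: training_init_op, validation_init_op, and next_batch are not touched by this diff, but they are the kind of ops a reinitializable tf.data iterator provides, which is presumably how they are built here. A minimal sketch under that assumption (the toy datasets below are placeholders, not the repository's image pipelines):

import tensorflow as tf

# Toy datasets standing in for the training and validation pipelines.
tr_dataset = tf.data.Dataset.from_tensor_slices(tf.range(10)).batch(2)
val_dataset = tf.data.Dataset.from_tensor_slices(tf.range(10, 16)).batch(2)

# One iterator structure that can be re-pointed at either dataset.
iterator = tf.data.Iterator.from_structure(tr_dataset.output_types,
                                           tr_dataset.output_shapes)
next_batch = iterator.get_next()

training_init_op = iterator.make_initializer(tr_dataset)
validation_init_op = iterator.make_initializer(val_dataset)

with tf.Session() as sess:
    sess.run(training_init_op)       # point the iterator at the training data
    print(sess.run(next_batch))      # first training batch
    sess.run(validation_init_op)     # switch to the validation data
    print(sess.run(next_batch))      # first validation batch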