Commit
Merge pull request #20 from openai/batch_size
Fix bugs when data size is not multiple of batch size
Ian Goodfellow authored Sep 19, 2016
2 parents 757877b + 4a60b39 commit 92f3122
Showing 3 changed files with 24 additions and 14 deletions.
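The core of the fix is visible in the hunks below: the old batch count relied on Python 2 integer division, which floors the quotient and silently drops the trailing partial batch, and the evaluation averaged per-batch accuracies without accounting for a smaller final batch. A minimal sketch of the batch-count arithmetic (numbers are illustrative, not from the repository):

```python
import math

def num_batches(n_examples, batch_size):
    # Python 2 integer division would floor n_examples / batch_size and
    # drop the trailing partial batch; casting to float before math.ceil
    # guarantees every example falls inside some batch.
    return int(math.ceil(float(n_examples) / batch_size))

# Illustrative: 1000 examples with batch size 128
assert 1000 // 128 == 7                       # floor loses the last 104 examples
assert num_batches(1000, 128) == 8            # ceil covers them
assert num_batches(1000, 128) * 128 >= 1000   # the invariant asserted in the patch
```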
cleverhans/utils.py (2 changes: 1 addition & 1 deletion)
@@ -73,4 +73,4 @@ def batch_indices(batch_nb, data_length, batch_size):
start -= shift
end -= shift

- return start, end
+ return start, end
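For context on why the evaluation code further down stops calling batch_indices: the shifting logic shown in this hunk slides the final window back so every batch stays full, which means the last batch re-uses earlier examples. That is acceptable during training but would bias an evaluation. A rough sketch of the behaviour, reconstructed from the visible lines rather than copied verbatim from utils.py:

```python
def batch_indices_sketch(batch_nb, data_length, batch_size):
    # Nominal window for this batch.
    start = batch_nb * batch_size
    end = (batch_nb + 1) * batch_size

    # If the window runs past the data, slide it back so the batch stays
    # full; the final batch then repeats some earlier examples.
    if end > data_length:
        shift = end - data_length
        start -= shift
        end -= shift

    return start, end

# Illustrative: 1000 examples, batch size 128; the 8th batch covers
# examples 872-999, repeating 872-895 from the previous batch.
print(batch_indices_sketch(7, 1000, 128))  # (872, 1000)
```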
cleverhans/utils_tf.py (28 changes: 19 additions & 9 deletions)
@@ -40,8 +40,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
:param X_train: numpy array with training inputs
:param Y_train: numpy array with training outputs
:param save: Boolean controling the save operation
- :param predictions_adv: if set with the adversarial example tensor,
- will run adversarial training
+ :param predictions_adv: if set with the adversarial example tensor,
+ will run adversarial training
:return: True if model trained
"""
print "Starting model training using TensorFlow."
@@ -63,7 +63,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
print("Epoch " + str(epoch))

# Compute number of batches
- nb_batches = int(math.ceil(len(X_train) / FLAGS.batch_size))
+ nb_batches = int(math.ceil(float(len(X_train)) / FLAGS.batch_size))
+ assert nb_batches * FLAGS.batch_size >= len(X_train)

prev = time.time()
for batch in range(nb_batches):
@@ -80,6 +81,7 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
train_step.run(feed_dict={x: X_train[start:end],
y: Y_train[start:end],
keras.backend.learning_phase(): 1})
+ assert end >= len(X_train) # Check that all examples were used


if save:
@@ -112,21 +114,29 @@ def tf_model_eval(sess, x, y, model, X_test, Y_test):

with sess.as_default():
# Compute number of batches
- nb_batches = int(math.ceil(len(X_test) / FLAGS.batch_size))
+ nb_batches = int(math.ceil(float(len(X_test)) / FLAGS.batch_size))
+ assert nb_batches * FLAGS.batch_size >= len(X_test)

for batch in range(nb_batches):
if batch % 100 == 0 and batch > 0:
print("Batch " + str(batch))

- # Compute batch start and end indices
- start, end = batch_indices(batch, len(X_test), FLAGS.batch_size)
+ # Must not use the `batch_indices` function here, because it
+ # repeats some examples.
+ # It's acceptable to repeat during training, but not eval.
+ start = batch * FLAGS.batch_size
+ end = min(len(X_test), start + FLAGS.batch_size)
+ cur_batch_size = end - start + 1

- accuracy += acc_value.eval(feed_dict={x: X_test[start:end],
+ # The last batch may be smaller than all others, so we need to
+ # account for variable batch size here
+ accuracy += cur_batch_size * acc_value.eval(feed_dict={x: X_test[start:end],
y: Y_test[start:end],
keras.backend.learning_phase(): 0})
+ assert end >= len(X_test)

- # Divide by number of batches to get final value
- accuracy /= nb_batches
+ # Divide by number of examples to get final value
+ accuracy /= len(X_test)

return accuracy

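The evaluation changes above replace the plain per-batch average with an example-weighted one: the final batch may contain fewer examples than the rest, so each batch accuracy is scaled by the number of examples it covers and the sum is divided by len(X_test). A minimal sketch of that bookkeeping (the helper name and numbers are illustrative, not from the repository):

```python
def weighted_accuracy(per_batch_acc, batch_sizes, n_examples):
    # Each entry of per_batch_acc is a mean over one batch, so weight it by
    # that batch's size before dividing by the total number of examples; a
    # plain mean over batches would over-weight the smaller final batch.
    assert sum(batch_sizes) == n_examples
    return sum(a * s for a, s in zip(per_batch_acc, batch_sizes)) / float(n_examples)

# Illustrative: 1000 examples, batch size 128 -> 7 full batches plus one of 104
sizes = [128] * 7 + [104]
accs = [1.0] * 7 + [0.5]
print(weighted_accuracy(accs, sizes, 1000))  # 0.948; a naive mean over batches gives 0.9375
```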
tests/test_mnist_accuracy.py (8 changes: 4 additions & 4 deletions)
@@ -51,11 +51,11 @@ def main(argv=None):

# Train an MNIST model
tf_model_train(sess, x, y, predictions, X_train, Y_train)

# Evaluate the accuracy of the MNIST model on legitimate test examples
accuracy = tf_model_eval(sess, x, y, predictions, X_test, Y_test)
- assert float(accuracy) >= 0.97
+ assert float(accuracy) >= 0.97, accuracy


if __name__ == '__main__':
app.run()
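The only change to the test is attaching the measured value to the assertion, so a failing run reports the accuracy rather than a bare AssertionError. For example:

```python
accuracy = 0.42  # hypothetical value below the threshold
assert float(accuracy) >= 0.97, accuracy
# Raises AssertionError: 0.42 (the expression after the comma becomes the message)
```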
