Contrastive loss layer for training siamese nets #959

Merged (1 commit, Sep 19, 2014)
examples/siamese/convert_mnist_siamese_data.cpp (123 additions, 0 deletions)
@@ -0,0 +1,123 @@
//
// This script converts the MNIST dataset to the leveldb format used
// by Caffe to train a siamese network.
// Usage:
// convert_mnist_siamese_data input_image_file input_label_file output_db_file
// The MNIST dataset can be downloaded at
// http://yann.lecun.com/exdb/mnist/
#include <fstream> // NOLINT(readability/streams)
#include <string>

#include "glog/logging.h"
#include "google/protobuf/text_format.h"
#include "leveldb/db.h"
#include "stdint.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/math_functions.hpp"

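// MNIST IDX files store multi-byte integers in big-endian order; this helper
// assumes a little-endian host and swaps the bytes before use.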
uint32_t swap_endian(uint32_t val) {
val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
return (val << 16) | (val >> 16);
}

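// Reads the image and label at a given index, skipping the 16-byte header of
// the image file and the 8-byte header of the label file.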
void read_image(std::ifstream* image_file, std::ifstream* label_file,
uint32_t index, uint32_t rows, uint32_t cols,
char* pixels, char* label) {
image_file->seekg(index * rows * cols + 16);
image_file->read(pixels, rows * cols);
label_file->seekg(index + 8);
label_file->read(label, 1);
}

void convert_dataset(const char* image_filename, const char* label_filename,
const char* db_filename) {
// Open files
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
CHECK(image_file) << "Unable to open file " << image_filename;
CHECK(label_file) << "Unable to open file " << label_filename;
// Read the magic and the meta data
uint32_t magic;
uint32_t num_items;
uint32_t num_labels;
uint32_t rows;
uint32_t cols;

image_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
label_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
CHECK_EQ(magic, 2049) << "Incorrect label file magic.";
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items);
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
num_labels = swap_endian(num_labels);
CHECK_EQ(num_items, num_labels);
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows);
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols);

// Open leveldb
leveldb::DB* db;
leveldb::Options options;
options.create_if_missing = true;
options.error_if_exists = true;
leveldb::Status status = leveldb::DB::Open(
options, db_filename, &db);
CHECK(status.ok()) << "Failed to open leveldb " << db_filename
<< ". Is it already existing?";

char label_i;
char label_j;
char* pixels = new char[2 * rows * cols];
const int kMaxKeyLength = 10;
char key[kMaxKeyLength];
std::string value;

caffe::Datum datum;
datum.set_channels(2); // one channel for each image in the pair
datum.set_height(rows);
datum.set_width(cols);
LOG(INFO) << "A total of " << num_items << " items.";
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
for (int itemid = 0; itemid < num_items; ++itemid) {
int i = caffe::caffe_rng_rand() % num_items; // pick a random pair
int j = caffe::caffe_rng_rand() % num_items;
read_image(&image_file, &label_file, i, rows, cols,
pixels, &label_i);
read_image(&image_file, &label_file, j, rows, cols,
pixels + (rows * cols), &label_j);
datum.set_data(pixels, 2*rows*cols);
if (label_i == label_j) {
datum.set_label(1);
} else {
datum.set_label(0);
}
datum.SerializeToString(&value);
snprintf(key, kMaxKeyLength, "%08d", itemid);
db->Put(leveldb::WriteOptions(), std::string(key), value);
}

delete db;
delete[] pixels;
}

int main(int argc, char** argv) {
if (argc != 4) {
printf("This script converts the MNIST dataset to the leveldb format used\n"
"by caffe to train a siamese network.\n"
"Usage:\n"
" convert_mnist_data input_image_file input_label_file "
"output_db_file\n"
"The MNIST dataset could be downloaded at\n"
" http://yann.lecun.com/exdb/mnist/\n"
"You should gunzip them after downloading.\n");
} else {
google::InitGoogleLogging(argv[0]);
convert_dataset(argv[1], argv[2], argv[3]);
}
return 0;
}
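As a quick sanity check of the database this converter writes, the following minimal read-back sketch (not part of the PR; the program name and verification logic are illustrative assumptions) opens the leveldb, parses the first record into a caffe::Datum, and reports the pair layout: two channels of rows x cols bytes and a label of 1 for a same-digit pair.

// inspect_pair_db.cpp (hypothetical helper, not in this PR): sanity-check the
// 2-channel pair format produced by convert_mnist_siamese_data.
#include <iostream>
#include <string>

#include "leveldb/db.h"

#include "caffe/proto/caffe.pb.h"

int main(int argc, char** argv) {
  if (argc != 2) {
    std::cerr << "Usage: inspect_pair_db db_path" << std::endl;
    return 1;
  }
  leveldb::DB* db;
  leveldb::Options options;  // defaults are fine for opening an existing db
  leveldb::Status status = leveldb::DB::Open(options, argv[1], &db);
  if (!status.ok()) {
    std::cerr << "Failed to open leveldb " << argv[1] << std::endl;
    return 1;
  }
  leveldb::Iterator* it = db->NewIterator(leveldb::ReadOptions());
  it->SeekToFirst();
  if (it->Valid()) {
    caffe::Datum datum;
    datum.ParseFromString(it->value().ToString());
    const int image_size = datum.height() * datum.width();
    // Channel 0 is the first image of the pair, channel 1 the second;
    // label 1 marks a same-digit pair, label 0 a different-digit pair.
    std::cout << "key=" << it->key().ToString()
              << " channels=" << datum.channels()
              << " label=" << datum.label()
              << " bytes/image=" << image_size
              << " total bytes=" << datum.data().size() << std::endl;
  }
  delete it;
  delete db;
  return 0;
}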
examples/siamese/create_mnist_siamese.sh (21 additions, 0 deletions)
@@ -0,0 +1,21 @@
#!/usr/bin/env sh
# This script converts the MNIST data into leveldb format.

EXAMPLES=./build/examples/siamese
DATA=./data/mnist

echo "Creating leveldb..."

rm -rf ./examples/siamese/mnist_siamese_train_leveldb
rm -rf ./examples/siamese/mnist_siamese_test_leveldb

$EXAMPLES/convert_mnist_siamese_data.bin \
$DATA/train-images-idx3-ubyte \
$DATA/train-labels-idx1-ubyte \
./examples/siamese/mnist_siamese_train_leveldb
$EXAMPLES/convert_mnist_siamese_data.bin \
$DATA/t10k-images-idx3-ubyte \
$DATA/t10k-labels-idx1-ubyte \
./examples/siamese/mnist_siamese_test_leveldb

echo "Done."
examples/siamese/mnist_siamese.ipynb (169 additions, 0 deletions)

(Large diff not rendered.)
examples/siamese/mnist_siamese.prototxt (95 additions, 0 deletions)
@@ -0,0 +1,95 @@
name: "mnist_siamese"
input: "data"
input_dim: 10000
input_dim: 1
input_dim: 28
input_dim: 28

layers {
name: "conv1"
type: CONVOLUTION
bottom: "data"
top: "conv1"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 20
kernel_size: 5
stride: 1
}
}
layers {
name: "pool1"
type: POOLING
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "conv2"
type: CONVOLUTION
bottom: "pool1"
top: "conv2"
blobs_lr: 1
blobs_lr: 2
convolution_param {
num_output: 50
kernel_size: 5
stride: 1
}
}
layers {
name: "pool2"
type: POOLING
bottom: "conv2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layers {
name: "ip1"
type: INNER_PRODUCT
bottom: "pool2"
top: "ip1"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 500
}
}
layers {
name: "relu1"
type: RELU
bottom: "ip1"
top: "ip1"
}
layers {
name: "ip2"
type: INNER_PRODUCT
bottom: "ip1"
top: "ip2"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 10
}
}

layers {
name: "feat"
type: INNER_PRODUCT
bottom: "ip2"
top: "feat"
blobs_lr: 1
blobs_lr: 2
inner_product_param {
num_output: 2
}
}
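The net above appears to be the single-branch feature-extraction model, ending in the 2-dimensional "feat" embedding; the training net named by the solver below (mnist_siamese_train_test.prototxt, not shown in this diff view) would run two weight-sharing copies of this branch and compare their "feat" outputs with the new contrastive loss layer. As a reference sketch only, a common statement of the contrastive loss (after Hadsell, Chopra, and LeCun, 2006), written with the label convention of the converter above (y = 1 for a same-digit pair, 0 otherwise), Euclidean distance d between the two feat vectors, and margin m, is:

L(y, d) = \frac{1}{2}\, y\, d^{2} + \frac{1}{2}\, (1 - y)\, \max(m - d, 0)^{2}

The exact expression implemented by the layer in this PR is not visible in this diff, so the formula above should be read as the motivating definition rather than the implementation.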
examples/siamese/mnist_siamese_solver.prototxt (25 additions, 0 deletions)
@@ -0,0 +1,25 @@
# The train/test net protocol buffer definition
net: "examples/siamese/mnist_siamese_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0000
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 50000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/siamese/mnist_siamese"
# solver mode: CPU or GPU
solver_mode: GPU
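With the solver definition in place, training would presumably be launched from the Caffe root, for example with ./build/tools/caffe train --solver=examples/siamese/mnist_siamese_solver.prototxt; the exact launch command or script is not part of this diff view, so take that invocation as an assumption.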