Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend accurancy on multiple channels labels #759

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions src/caffe/layers/accuracy_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "caffe/util/io.hpp"

namespace caffe {

template <typename Dtype>
void AccuracyLayer<Dtype>::SetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
Expand All @@ -21,7 +20,8 @@ void AccuracyLayer<Dtype>::SetUp(
<< "The data and label should have the same number.";
CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
<< "top_k must be less than or equal to the number of classes.";
CHECK_EQ(bottom[1]->channels(), 1);
CHECK_GE(bottom[0]->channels(), bottom[1]->channels());
CHECK_EQ(bottom[0]->channels() % bottom[1]->channels(), 0);
CHECK_EQ(bottom[1]->height(), 1);
CHECK_EQ(bottom[1]->width(), 1);
(*top)[0]->Reshape(1, 1, 1, 1);
Expand All @@ -34,28 +34,38 @@ Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* bottom_label = bottom[1]->cpu_data();
int num = bottom[0]->num();
int dim = bottom[0]->count() / bottom[0]->num();
int channels = bottom[0]->channels();
int label_channels = bottom[1]->channels();
int data_chan_split = channels / bottom[1]->channels();;
vector<Dtype> maxval(top_k_+1);
vector<int> max_id(top_k_+1);
for (int i = 0; i < num; ++i) {
// Top-k accuracy
std::vector<std::pair<Dtype, int> > bottom_data_vector;
for (int j = 0; j < dim; ++j) {
bottom_data_vector.push_back(
std::make_pair(bottom_data[i * dim + j], j));
}
std::partial_sort(
bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
// check if true label is in top k predictions
for (int k = 0; k < top_k_; k++) {
if (bottom_data_vector[k].second == static_cast<int>(bottom_label[i])) {
++accuracy;
break;
int intra_chan_accuracy = 0;
int data_step = 0;
for (int c = 0; c < label_channels; ++c) {
// Top-k accuracy
std::vector<std::pair<Dtype, int> > bottom_data_vector;
for (int j = 0; j < data_chan_split; ++j) {
bottom_data_vector.push_back(
std::make_pair(bottom_data[i * channels + data_step + j], j));
}
std::partial_sort(
bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
// check if true label is in top k predictions
for (int k = 0; k < top_k_; k++) {
if (bottom_data_vector[k].second
== static_cast<int>(bottom_label[i * label_channels + c ])) {
++intra_chan_accuracy;
break;
}
}
data_step += data_chan_split;
}
if (intra_chan_accuracy == label_channels) {
++accuracy;
}
}

// LOG(INFO) << "Accuracy: " << accuracy;
(*top)[0]->mutable_cpu_data()[0] = accuracy / num;

Expand Down
60 changes: 42 additions & 18 deletions src/caffe/test/test_accuracy_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class AccuracyLayerTest : public ::testing::Test {
protected:
AccuracyLayerTest()
: blob_bottom_data_(new Blob<Dtype>(100, 10, 1, 1)),
blob_bottom_label_(new Blob<Dtype>(100, 1, 1, 1)),
blob_bottom_label_(new Blob<Dtype>(100, 2, 1, 1)),
blob_top_(new Blob<Dtype>()),
top_k_(3) {
// fill the probability values
Expand Down Expand Up @@ -90,16 +90,28 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPU) {
TypeParam max_value;
int max_id;
int num_correct_labels = 0;
int data_split = this->blob_bottom_data_->channels()/
this->blob_bottom_label_->channels();
for (int i = 0; i < 100; ++i) {
max_value = -FLT_MAX;
max_id = 0;
for (int j = 0; j < 10; ++j) {
if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) {
max_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
max_id = j;
int split_offset = 0;
int num_chan_correct_labels = 0;
for (int c = 0; c < this->blob_bottom_label_->channels(); ++c) {
max_value = -FLT_MAX;
max_id = 0;
for (int j = 0; j < data_split; ++j) {
if (this->blob_bottom_data_->data_at(i, j +
split_offset, 0, 0) > max_value) {
max_value = this->blob_bottom_data_->data_at(i, j +
split_offset, 0, 0);
max_id = j;
}
}
if (max_id == this->blob_bottom_label_->data_at(i, c, 0, 0)) {
++num_chan_correct_labels;
}
split_offset += data_split;
}
if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
if (num_chan_correct_labels == this->blob_bottom_label_->channels()) {
++num_correct_labels;
}
}
Expand All @@ -118,19 +130,31 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUTopK) {
TypeParam current_value;
int current_rank;
int num_correct_labels = 0;
int data_split = this->blob_bottom_data_->channels()/
this->blob_bottom_label_->channels();
for (int i = 0; i < 100; ++i) {
for (int j = 0; j < 10; ++j) {
current_value = this->blob_bottom_data_->data_at(i, j, 0, 0);
current_rank = 0;
for (int k = 0; k < 10; ++k) {
if (this->blob_bottom_data_->data_at(i, k, 0, 0) > current_value) {
++current_rank;
int split_offset = 0;
int num_chan_correct_labels = 0;
for (int c = 0; c < this->blob_bottom_label_->channels(); ++c) {
for (int j = 0; j < data_split; ++j) {
current_value = this->blob_bottom_data_->data_at(i, j +
split_offset, 0, 0);
current_rank = 0;
for (int k = 0; k < data_split; ++k) {
if (this->blob_bottom_data_->data_at(i, k +
split_offset, 0, 0) > current_value) {
++current_rank;
}
}
if (current_rank < this->top_k_ &&
j == this->blob_bottom_label_->data_at(i, c, 0, 0)) {
++num_chan_correct_labels;
}
}
if (current_rank < this->top_k_ &&
j == this->blob_bottom_label_->data_at(i, 0, 0, 0)) {
++num_correct_labels;
}
split_offset += data_split;
}
if (num_chan_correct_labels == this->blob_bottom_label_->channels()) {
++num_correct_labels;
}
}

Expand Down