// Patch summary (commentary only; patch tools skip text before the first
// "diff --git" header).
//
// Files touched:
//   src/caffe/layers/accuracy_layer.cpp
//   src/caffe/test/test_accuracy_layer.cpp
//
// What the hunks below do:
//   * SetUp: drops the requirement that the label blob (bottom[1]) has exactly
//     one channel. Instead it requires bottom[0]->channels() to be >= and an
//     exact multiple of bottom[1]->channels().
//   * Forward_cpu: treats the data channels of each sample as label_channels
//     consecutive slices of data_chan_split values each. For every slice it
//     partial-sorts (value, index-within-slice) pairs in descending order and
//     checks whether the label for that slice is among the first top_k_
//     indices. A sample adds 1 to accuracy only when every slice matched.
//   * Forward_cpu: removes the commented-out LOG(INFO) line.
//   * Tests: the label blob becomes (100, 2, 1, 1) against data (100, 10, 1, 1),
//     and both expected-value loops are rewritten to apply the same
//     "all slices must match" rule per sample.
//
// NOTE(review): the unchanged context check compares top_k_ against
//   bottom[0]->count() / bottom[0]->num(), not against the per-slice size.
//   partial_sort is called with begin() + top_k_ on a vector holding
//   data_chan_split elements, so top_k_ > data_chan_split would place the
//   middle iterator past end(). Consider checking top_k_ against the slice size.
// NOTE(review): data is indexed with i * channels, which matches the retained
//   count()/num() check only when bottom[0] has height * width == 1. No such
//   check on bottom[0] is visible in these hunks -- confirm.
// NOTE(review): bottom[1]->channels() is used as a divisor (modulo in SetUp,
//   division in Forward_cpu) with no visible guard against 0 -- confirm.
// NOTE(review): "int data_chan_split = ...;;" carries a stray second semicolon.
// NOTE(review): with 10 data channels and 2 label channels each slice has 5
//   entries, so a label can only match if it lies in [0, 5). The code that
//   fills the label blob is not part of these hunks -- verify its range.
// NOTE(review): in this copy the line breaks are collapsed and template
//   argument lists appear to be missing ("template void", "vector*>&",
//   "static_cast(") -- presumably an extraction artifact. Re-export the patch
//   from version control before attempting to apply it.
diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp index 7c913114abc..4de7c8bd335 100644 --- a/src/caffe/layers/accuracy_layer.cpp +++ b/src/caffe/layers/accuracy_layer.cpp @@ -11,7 +11,6 @@ #include "caffe/util/io.hpp" namespace caffe { - template void AccuracyLayer::SetUp( const vector*>& bottom, vector*>* top) { @@ -21,7 +20,8 @@ void AccuracyLayer::SetUp( << "The data and label should have the same number."; CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) << "top_k must be less than or equal to the number of classes."; - CHECK_EQ(bottom[1]->channels(), 1); + CHECK_GE(bottom[0]->channels(), bottom[1]->channels()); + CHECK_EQ(bottom[0]->channels() % bottom[1]->channels(), 0); CHECK_EQ(bottom[1]->height(), 1); CHECK_EQ(bottom[1]->width(), 1); (*top)[0]->Reshape(1, 1, 1, 1); @@ -34,28 +34,38 @@ Dtype AccuracyLayer::Forward_cpu(const vector*>& bottom, const Dtype* bottom_data = bottom[0]->cpu_data(); const Dtype* bottom_label = bottom[1]->cpu_data(); int num = bottom[0]->num(); - int dim = bottom[0]->count() / bottom[0]->num(); + int channels = bottom[0]->channels(); + int label_channels = bottom[1]->channels(); + int data_chan_split = channels / bottom[1]->channels();; vector maxval(top_k_+1); vector max_id(top_k_+1); for (int i = 0; i < num; ++i) { - // Top-k accuracy - std::vector > bottom_data_vector; - for (int j = 0; j < dim; ++j) { - bottom_data_vector.push_back( - std::make_pair(bottom_data[i * dim + j], j)); - } - std::partial_sort( - bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, - bottom_data_vector.end(), std::greater >()); - // check if true label is in top k predictions - for (int k = 0; k < top_k_; k++) { - if (bottom_data_vector[k].second == static_cast(bottom_label[i])) { - ++accuracy; - break; + int intra_chan_accuracy = 0; + int data_step = 0; + for (int c = 0; c < label_channels; ++c) { + // Top-k accuracy + std::vector > bottom_data_vector; + for (int j = 0; j < data_chan_split; ++j) 
{ + bottom_data_vector.push_back( + std::make_pair(bottom_data[i * channels + data_step + j], j)); + } + std::partial_sort( + bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, + bottom_data_vector.end(), std::greater >()); + // check if true label is in top k predictions + for (int k = 0; k < top_k_; k++) { + if (bottom_data_vector[k].second + == static_cast(bottom_label[i * label_channels + c ])) { + ++intra_chan_accuracy; + break; + } } + data_step += data_chan_split; + } + if (intra_chan_accuracy == label_channels) { + ++accuracy; } } - // LOG(INFO) << "Accuracy: " << accuracy; (*top)[0]->mutable_cpu_data()[0] = accuracy / num; diff --git a/src/caffe/test/test_accuracy_layer.cpp b/src/caffe/test/test_accuracy_layer.cpp index 5023a809d17..57aa09953f5 100644 --- a/src/caffe/test/test_accuracy_layer.cpp +++ b/src/caffe/test/test_accuracy_layer.cpp @@ -21,7 +21,7 @@ class AccuracyLayerTest : public ::testing::Test { protected: AccuracyLayerTest() : blob_bottom_data_(new Blob(100, 10, 1, 1)), - blob_bottom_label_(new Blob(100, 1, 1, 1)), + blob_bottom_label_(new Blob(100, 2, 1, 1)), blob_top_(new Blob()), top_k_(3) { // fill the probability values @@ -90,16 +90,28 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPU) { TypeParam max_value; int max_id; int num_correct_labels = 0; + int data_split = this->blob_bottom_data_->channels()/ + this->blob_bottom_label_->channels(); for (int i = 0; i < 100; ++i) { - max_value = -FLT_MAX; - max_id = 0; - for (int j = 0; j < 10; ++j) { - if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) { - max_value = this->blob_bottom_data_->data_at(i, j, 0, 0); - max_id = j; + int split_offset = 0; + int num_chan_correct_labels = 0; + for (int c = 0; c < this->blob_bottom_label_->channels(); ++c) { + max_value = -FLT_MAX; + max_id = 0; + for (int j = 0; j < data_split; ++j) { + if (this->blob_bottom_data_->data_at(i, j + + split_offset, 0, 0) > max_value) { + max_value = this->blob_bottom_data_->data_at(i, j + + 
split_offset, 0, 0); + max_id = j; + } + } + if (max_id == this->blob_bottom_label_->data_at(i, c, 0, 0)) { + ++num_chan_correct_labels; } + split_offset += data_split; } - if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) { + if (num_chan_correct_labels == this->blob_bottom_label_->channels()) { ++num_correct_labels; } } @@ -118,19 +130,31 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUTopK) { TypeParam current_value; int current_rank; int num_correct_labels = 0; + int data_split = this->blob_bottom_data_->channels()/ + this->blob_bottom_label_->channels(); for (int i = 0; i < 100; ++i) { - for (int j = 0; j < 10; ++j) { - current_value = this->blob_bottom_data_->data_at(i, j, 0, 0); - current_rank = 0; - for (int k = 0; k < 10; ++k) { - if (this->blob_bottom_data_->data_at(i, k, 0, 0) > current_value) { - ++current_rank; + int split_offset = 0; + int num_chan_correct_labels = 0; + for (int c = 0; c < this->blob_bottom_label_->channels(); ++c) { + for (int j = 0; j < data_split; ++j) { + current_value = this->blob_bottom_data_->data_at(i, j + + split_offset, 0, 0); + current_rank = 0; + for (int k = 0; k < data_split; ++k) { + if (this->blob_bottom_data_->data_at(i, k + + split_offset, 0, 0) > current_value) { + ++current_rank; + } + } + if (current_rank < this->top_k_ && + j == this->blob_bottom_label_->data_at(i, c, 0, 0)) { + ++num_chan_correct_labels; } } - if (current_rank < this->top_k_ && - j == this->blob_bottom_label_->data_at(i, 0, 0, 0)) { - ++num_correct_labels; - } + split_offset += data_split; + } + if (num_chan_correct_labels == this->blob_bottom_label_->channels()) { + ++num_correct_labels; } }