Skip to content

Commit

Permalink
Fix AccuracyLayerTest for per-class accuracy.
Browse files Browse the repository at this point in the history
Fix AccuracyLayerTest for per-class accuracy. Previously in BVLC#2935, it crashes since the test accuracy is nan (0/0) when a class never appear.
  • Loading branch information
ronghanghu authored and acmiyaguchi committed Nov 13, 2017
1 parent 814ca0a commit 770bd3e
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/caffe/test/test_accuracy_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUTopK) {

TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClass) {
LayerParameter layer_param;
Caffe::set_mode(Caffe::CPU);
AccuracyLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_per_class_vec_);
Expand Down Expand Up @@ -279,16 +278,16 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClass) {
EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
num_correct_labels / 100.0, 1e-4);
for (int i = 0; i < num_class; ++i) {
TypeParam accuracy_per_class = (num_per_class[i] > 0 ?
static_cast<TypeParam>(correct_per_class[i]) / num_per_class[i] : 0);
EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0),
static_cast<float>(correct_per_class[i]) / num_per_class[i],
1e-4);
accuracy_per_class, 1e-4);
}
}


TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClassWithIgnoreLabel) {
LayerParameter layer_param;
Caffe::set_mode(Caffe::CPU);
const TypeParam kIgnoreLabelValue = -1;
layer_param.mutable_accuracy_param()->set_ignore_label(kIgnoreLabelValue);
AccuracyLayer<TypeParam> layer(layer_param);
Expand Down Expand Up @@ -329,9 +328,10 @@ TYPED_TEST(AccuracyLayerTest, TestForwardCPUPerClassWithIgnoreLabel) {
EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0),
num_correct_labels / TypeParam(count), 1e-4);
for (int i = 0; i < 10; ++i) {
TypeParam accuracy_per_class = (num_per_class[i] > 0 ?
static_cast<TypeParam>(correct_per_class[i]) / num_per_class[i] : 0);
EXPECT_NEAR(this->blob_top_per_class_->data_at(i, 0, 0, 0),
TypeParam(correct_per_class[i]) / num_per_class[i],
1e-4);
accuracy_per_class, 1e-4);
}
}

Expand Down

0 comments on commit 770bd3e

Please sign in to comment.