Skip to content

Commit

Permalink
basic tests (Forward, Gradient) for ReshapeLayer
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdonahue committed Mar 27, 2015
1 parent 66fbc87 commit 2f7df49
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions src/caffe/test/test_reshape_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,61 @@ TYPED_TEST(ReshapeLayerTest, TestInferenceOfUnspecified) {
EXPECT_EQ(this->blob_top_->width(), 3);
}

TYPED_TEST(ReshapeLayerTest, TestForward) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
BlobShape* shape = layer_param.mutable_reshape_param()->mutable_shape();
shape->add_dim(6);
shape->add_dim(2);
shape->add_dim(3);
shape->add_dim(5);
ReshapeLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
EXPECT_EQ(this->blob_top_->cpu_data()[i],
this->blob_bottom_->cpu_data()[i]);
}
}

TYPED_TEST(ReshapeLayerTest, TestForwardAfterReshape) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
BlobShape* shape = layer_param.mutable_reshape_param()->mutable_shape();
shape->add_dim(6);
shape->add_dim(2);
shape->add_dim(3);
shape->add_dim(5);
ReshapeLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
// We know the above produced the correct result from TestForward.
// Reshape the bottom and call layer.Reshape, then try again.
vector<int> new_bottom_shape(1, 2 * 3 * 6 * 5);
this->blob_bottom_->Reshape(new_bottom_shape);
layer.Reshape(this->blob_bottom_vec_, this->blob_top_vec_);
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
EXPECT_EQ(this->blob_top_->cpu_data()[i],
this->blob_bottom_->cpu_data()[i]);
}
}

TYPED_TEST(ReshapeLayerTest, TestGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
BlobShape* shape = layer_param.mutable_reshape_param()->mutable_shape();
shape->add_dim(6);
shape->add_dim(2);
shape->add_dim(3);
shape->add_dim(5);
ReshapeLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-2);
checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
this->blob_top_vec_);
}

} // namespace caffe

0 comments on commit 2f7df49

Please sign in to comment.