Skip to content

Commit

Permalink
Add ReshapeParameter axis and num_axes to reshape only a particular span
Browse files Browse the repository at this point in the history
of the input shape
  • Loading branch information
jeffdonahue committed May 14, 2015
1 parent 9d6827f commit 67266d0
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 14 deletions.
2 changes: 0 additions & 2 deletions include/caffe/common_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,6 @@ class ReshapeLayer : public Layer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {}

/// @brief the current output shape
vector<int> top_shape_;
/// @brief vector of axes indices whose dimensions we'll copy from the bottom
vector<int> copy_axes_;
/// @brief the index of the axis whose dimension we infer, or -1 if none
Expand Down
53 changes: 42 additions & 11 deletions src/caffe/layers/reshape_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,77 @@ void ReshapeLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
copy_axes_.clear();
const BlobShape& top_blob_shape = this->layer_param_.reshape_param().shape();
const int top_num_axes = top_blob_shape.dim_size();
top_shape_.resize(top_num_axes);
constant_count_ = 1;
for (int i = 0; i < top_num_axes; ++i) {
top_shape_[i] = top_blob_shape.dim(i);
if (top_shape_[i] == 0) {
const int top_dim = top_blob_shape.dim(i);
if (top_dim == 0) {
copy_axes_.push_back(i);
} else if (top_shape_[i] == -1) {
} else if (top_dim == -1) {
CHECK_EQ(inferred_axis_, -1) << "new shape contains multiple "
<< "-1 dims; at most a single (1) value of -1 may be specified";
inferred_axis_ = i;
} else {
constant_count_ *= top_shape_[i];
constant_count_ *= top_dim;
}
}
}

template <typename Dtype>
void ReshapeLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int input_start_axis = this->layer_param_.reshape_param().axis();
const int start_axis = (input_start_axis >= 0) ? input_start_axis :
bottom[0]->num_axes() + input_start_axis + 1;
CHECK_GE(start_axis, 0) << "axis " << input_start_axis << " out of range";
CHECK_LE(start_axis, bottom[0]->num_axes()) << "axis " << input_start_axis
<< " out of range for " << bottom[0]->num_axes() << "-D input blob";
const int num_axes = this->layer_param_.reshape_param().num_axes();
CHECK_GE(num_axes, -1) << "num_axes must be >= 0, or -1 for all";
const int end_axis =
(num_axes == -1) ? bottom[0]->num_axes() : (start_axis + num_axes);
CHECK_LE(end_axis, bottom[0]->num_axes())
<< "end_axis = axis + num_axes is out of range";
const int num_axes_replaced = end_axis - start_axis;
const int num_axes_retained = bottom[0]->num_axes() - num_axes_replaced;
const BlobShape& top_blob_shape = this->layer_param_.reshape_param().shape();
const int num_new_axes = top_blob_shape.dim_size();
vector<int> top_shape(num_axes_retained + num_new_axes);
int top_shape_index = 0;
for (int i = 0; i < start_axis; ++i) {
top_shape[top_shape_index++] = bottom[0]->shape(i);
}
for (int i = 0; i < num_new_axes; ++i) {
top_shape[top_shape_index++] = top_blob_shape.dim(i);
}
for (int i = end_axis; i < bottom[0]->num_axes(); ++i) {
top_shape[top_shape_index++] = bottom[0]->shape(i);
}
CHECK_EQ(top_shape_index, top_shape.size());
for (int i = 0; i < copy_axes_.size(); ++i) {
const int copy_axis_index = copy_axes_[i];
CHECK_GT(bottom[0]->num_axes(), copy_axis_index) << "new shape contains "
<< "a 0, but there is no corresponding bottom axis to copy";
top_shape_[copy_axis_index] = bottom[0]->shape(copy_axis_index);
CHECK_GT(bottom[0]->num_axes(), start_axis + copy_axis_index)
<< "new shape contains a 0, but there was no corresponding bottom axis "
<< "to copy";
top_shape[start_axis + copy_axis_index] =
bottom[0]->shape(start_axis + copy_axis_index);
}
if (inferred_axis_ >= 0) {
// A -1 dim was specified; infer the correct dimension by computing the
// product of the other dimensions.
int explicit_count = constant_count_;
explicit_count *= bottom[0]->count(0, start_axis);
explicit_count *= bottom[0]->count(end_axis);
for (int i = 0; i < copy_axes_.size(); ++i) {
const int copy_axis_index = copy_axes_[i];
explicit_count *= top_shape_[copy_axis_index];
explicit_count *= top_shape[start_axis + copy_axis_index];
}
CHECK_EQ(0, bottom[0]->count() % explicit_count) << "bottom count ("
<< bottom[0]->count() << ") must be divisible by the product of "
<< "the specified dimensions (" << explicit_count << ")";
const int inferred_dim = bottom[0]->count() / explicit_count;
top_shape_[inferred_axis_] = inferred_dim;
top_shape[start_axis + inferred_axis_] = inferred_dim;
}
top[0]->Reshape(top_shape_);
top[0]->Reshape(top_shape);
CHECK_EQ(top[0]->count(), bottom[0]->count())
<< "output count must match input count";
top[0]->ShareData(*bottom[0]);
Expand Down
58 changes: 57 additions & 1 deletion src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,64 @@ message ReshapeParameter {
// Specify the output dimensions. If some of the dimensions are set to 0,
// the corresponding dimension from the bottom layer is used (unchanged).
// Exactly one dimension may be set to -1, in which case its value is
// inferred from the count of the bottom layer and the remaining dimensions.
// inferred from the count of the bottom blob and the remaining dimensions.
// For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8:
//
// layer {
// type: "Reshape" bottom: "input" top: "output"
// reshape_param { ... }
// }
//
// If "input" is 2D with shape 2 x 8, then the following reshape_param
// specifications are all equivalent, producing a 3D blob "output" with shape
// 2 x 2 x 4:
//
// reshape_param { shape { dim: 2 dim: 2 dim: 4 } }
// reshape_param { shape { dim: 0 dim: 2 dim: 4 } }
// reshape_param { shape { dim: 0 dim: 2 dim: -1 } }
// reshape_param { shape { dim: -1 dim: 0 dim: 2 } }
//
optional BlobShape shape = 1;

// axis and num_axes control the portion of the bottom blob's shape that is
// replaced by (included in) the reshape. By default (axis == 0 and
// num_axes == -1), the entire bottom blob shape is included in the reshape,
// and hence the shape field must specify the entire output shape.
//
// axis may be non-zero to retain some portion of the beginning of the input
// shape (and may be negative to index from the end; e.g., -1 to begin the
// reshape after the last axis, including nothing in the reshape,
// -2 to include only the last axis, etc.).
//
// For example, suppose "input" is a 2D blob with shape 2 x 8.
// Then the following ReshapeLayer specifications are all equivalent,
// producing a blob "output" with shape 2 x 2 x 4:
//
// reshape_param { shape { dim: 2 dim: 2 dim: 4 } }
// reshape_param { shape { dim: 2 dim: 4 } axis: 1 }
// reshape_param { shape { dim: 2 dim: 4 } axis: -3 }
//
// num_axes specifies the extent of the reshape.
// If num_axes >= 0 (and axis >= 0), the reshape will be performed only on
// input axes in the half-open range [axis, axis + num_axes) -- i.e.,
// num_axes axes are replaced, starting at axis.
// num_axes may also be -1, the default, to include all remaining axes
// (starting from axis).
//
// For example, suppose "input" is a 2D blob with shape 2 x 8.
// Then the following ReshapeLayer specifications are equivalent,
// producing a blob "output" with shape 1 x 2 x 8.
//
// reshape_param { shape { dim: 1 dim: 2 dim: 8 } }
// reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 }
// reshape_param { shape { dim: 1 } num_axes: 0 }
//
// On the other hand, these would produce output blob shape 2 x 1 x 8:
//
// reshape_param { shape { dim: 2 dim: 1 dim: 8 } }
// reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 }
//
optional int32 axis = 2 [default = 0];
optional int32 num_axes = 3 [default = -1];
}

// Message that stores parameters used by SigmoidLayer
Expand Down
105 changes: 105 additions & 0 deletions src/caffe/test/test_reshape_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,111 @@ TYPED_TEST(ReshapeLayerTest, TestInferenceOfUnspecified) {
EXPECT_EQ(this->blob_top_->width(), 3);
}

TYPED_TEST(ReshapeLayerTest, TestInferenceOfUnspecifiedWithStartAxis) {
  typedef typename TypeParam::Dtype Dtype;
  // With axis == 1, bottom axis 0 is retained and the rest of the shape is
  // replaced by { 3, 10, -1 }; the -1 dim is inferred from the bottom count.
  // (The expectations below encode a 2 x 3 x 6 x 5 fixture bottom, so the
  // inferred dim is 180 / (2 * 3 * 10) = 3.)
  LayerParameter layer_param;
  layer_param.mutable_reshape_param()->set_axis(1);
  BlobShape* shape = layer_param.mutable_reshape_param()->mutable_shape();
  const int new_dims[] = {3, 10, -1};
  for (int i = 0; i < 3; ++i) {
    shape->add_dim(new_dims[i]);
  }

  ReshapeLayer<Dtype> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);

  const int expected_shape[] = {2, 3, 10, 3};
  ASSERT_EQ(this->blob_top_->num_axes(), 4);
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(this->blob_top_->shape(i), expected_shape[i]);
  }
}

TYPED_TEST(ReshapeLayerTest, TestInsertSingletonAxesStart) {
  typedef typename TypeParam::Dtype Dtype;
  // num_axes == 0 replaces nothing: the three singleton dims are simply
  // inserted in front of axis 0, keeping every bottom axis intact.
  LayerParameter layer_param;
  layer_param.mutable_reshape_param()->set_axis(0);
  layer_param.mutable_reshape_param()->set_num_axes(0);
  BlobShape* shape = layer_param.mutable_reshape_param()->mutable_shape();
  for (int i = 0; i < 3; ++i) {
    shape->add_dim(1);
  }

  ReshapeLayer<Dtype> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);

  // Expected: 1 x 1 x 1 prepended to the original 2 x 3 x 6 x 5 shape.
  const int expected_shape[] = {1, 1, 1, 2, 3, 6, 5};
  ASSERT_EQ(this->blob_top_->num_axes(), 7);
  for (int i = 0; i < 7; ++i) {
    EXPECT_EQ(this->blob_top_->shape(i), expected_shape[i]);
  }
}

TYPED_TEST(ReshapeLayerTest, TestInsertSingletonAxesMiddle) {
  typedef typename TypeParam::Dtype Dtype;
  // Same as the "Start" variant, but the three singleton dims are inserted
  // before bottom axis 2; no bottom axes are replaced (num_axes == 0).
  LayerParameter layer_param;
  layer_param.mutable_reshape_param()->set_axis(2);
  layer_param.mutable_reshape_param()->set_num_axes(0);
  BlobShape* shape = layer_param.mutable_reshape_param()->mutable_shape();
  for (int i = 0; i < 3; ++i) {
    shape->add_dim(1);
  }

  ReshapeLayer<Dtype> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);

  // Expected: 1 x 1 x 1 spliced into the middle of 2 x 3 x 6 x 5.
  const int expected_shape[] = {2, 3, 1, 1, 1, 6, 5};
  ASSERT_EQ(this->blob_top_->num_axes(), 7);
  for (int i = 0; i < 7; ++i) {
    EXPECT_EQ(this->blob_top_->shape(i), expected_shape[i]);
  }
}

TYPED_TEST(ReshapeLayerTest, TestInsertSingletonAxesEnd) {
  typedef typename TypeParam::Dtype Dtype;
  // axis == -1 indexes from the end: the reshape begins after the last bottom
  // axis, so the three singleton dims are appended (num_axes == 0 replaces
  // nothing).
  LayerParameter layer_param;
  layer_param.mutable_reshape_param()->set_axis(-1);
  layer_param.mutable_reshape_param()->set_num_axes(0);
  BlobShape* shape = layer_param.mutable_reshape_param()->mutable_shape();
  for (int i = 0; i < 3; ++i) {
    shape->add_dim(1);
  }

  ReshapeLayer<Dtype> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);

  // Expected: 1 x 1 x 1 appended to the original 2 x 3 x 6 x 5 shape.
  const int expected_shape[] = {2, 3, 6, 5, 1, 1, 1};
  ASSERT_EQ(this->blob_top_->num_axes(), 7);
  for (int i = 0; i < 7; ++i) {
    EXPECT_EQ(this->blob_top_->shape(i), expected_shape[i]);
  }
}

TYPED_TEST(ReshapeLayerTest, TestFlattenMiddle) {
  typedef typename TypeParam::Dtype Dtype;
  // Replace exactly two bottom axes (1 and 2) with a single inferred dim,
  // collapsing 3 x 6 into 18 while leaving the outer axes untouched.
  LayerParameter layer_param;
  layer_param.mutable_reshape_param()->set_axis(1);
  layer_param.mutable_reshape_param()->set_num_axes(2);
  layer_param.mutable_reshape_param()->mutable_shape()->add_dim(-1);

  ReshapeLayer<Dtype> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);

  const int expected_shape[] = {2, 3 * 6, 5};
  ASSERT_EQ(this->blob_top_->num_axes(), 3);
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(this->blob_top_->shape(i), expected_shape[i]);
  }
}

TYPED_TEST(ReshapeLayerTest, TestForward) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
Expand Down

0 comments on commit 67266d0

Please sign in to comment.