Split concat_layer into .cpp and .cu, cleaned lint errors
sguada committed Feb 28, 2014
1 parent eb56c5b commit 93afc15
Showing 3 changed files with 134 additions and 110 deletions.
138 changes: 43 additions & 95 deletions src/caffe/layers/concat_layer.cpp
@@ -11,147 +11,95 @@ namespace caffe {
 template <typename Dtype>
 void ConcatLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
-  CHECK_GT(bottom.size(), 1) << "Concat Layer takes at least two blobs as input.";
-  CHECK_EQ(top->size(), 1) << "Concat Layer takes a single blob as output.";
+  CHECK_GT(bottom.size(), 1) <<
+    "Concat Layer takes at least two blobs as input.";
+  CHECK_EQ(top->size(), 1) <<
+    "Concat Layer takes a single blob as output.";
   concat_dim_ = this->layer_param_.concat_dim();
-  CHECK_GE(concat_dim_,0) << "concat_dim should be >= 0";
-  CHECK_LE(concat_dim_,1) <<
+  CHECK_GE(concat_dim_, 0) << "concat_dim should be >= 0";
+  CHECK_LE(concat_dim_, 1) <<
     "For now concat_dim <=1, it can only concat num and channels";
   // Intialize with the first blob
   COUNT_ = bottom[0]->count();
   NUM_ = bottom[0]->num();
   CHANNELS_ = bottom[0]->channels();
   HEIGHT_ = bottom[0]->height();
-  WIDTH_ = bottom[0]->width();
-  for (int i=1; i<bottom.size(); ++i) {
+  WIDTH_ = bottom[0]->width();
+  for (int i = 1; i < bottom.size(); ++i) {
     COUNT_ += bottom[i]->count();
-    if (concat_dim_==0) {
-      NUM_ += bottom[i]->num();
-    } else if (concat_dim_ == 1){
+    if (concat_dim_== 0) {
+      NUM_ += bottom[i]->num();
+    } else if (concat_dim_ == 1) {
       CHANNELS_ += bottom[i]->channels();
     } else if (concat_dim_ == 2) {
       HEIGHT_ += bottom[i]->height();
     } else if (concat_dim_ == 3) {
-      WIDTH_ += bottom[i]->width();
+      WIDTH_ += bottom[i]->width();
     }
   }
-  (*top)[0]->Reshape(NUM_, CHANNELS_, HEIGHT_, WIDTH_);
+  (*top)[0]->Reshape(NUM_, CHANNELS_, HEIGHT_, WIDTH_);
   CHECK_EQ(COUNT_, (*top)[0]->count());
-};
+}

 template <typename Dtype>
 void ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
+      vector<Blob<Dtype>*>* top) {
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  if (concat_dim_==0) {
-    int offset_num = 0;
-    for (int i=0; i<bottom.size(); ++i) {
+  if (concat_dim_== 0) {
+    int offset_num = 0;
+    for (int i = 0; i < bottom.size(); ++i) {
       const Dtype* bottom_data = bottom[i]->cpu_data();
       int num_elem = bottom[i]->count();
       caffe_copy(num_elem, bottom_data, top_data+(*top)[0]->offset(offset_num));
       offset_num += bottom[i]->num();
     }
   } else if (concat_dim_ == 1) {
-    int offset_channel = 0;
-    for (int i=0; i<bottom.size(); ++i) {
+    int offset_channel = 0;
+    for (int i = 0; i < bottom.size(); ++i) {
       const Dtype* bottom_data = bottom[i]->cpu_data();
-      int num_elem = bottom[i]->channels()*bottom[i]->height()*bottom[i]->width();
-      for (int n=0; n<NUM_; ++n){
+      int num_elem =
+        bottom[i]->channels()*bottom[i]->height()*bottom[i]->width();
+      for (int n = 0; n < NUM_; ++n) {
         caffe_copy(num_elem, bottom_data+bottom[i]->offset(n),
-          top_data+(*top)[0]->offset(n,offset_channel));
-      }
+          top_data+(*top)[0]->offset(n, offset_channel));
+      }
       offset_channel += bottom[i]->channels();
     }
   } else {
-    LOG(FATAL) << "concat_dim along dim" << concat_dim_ << " not implemented yet";
-  }
-}
-
-template <typename Dtype>
-void ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  Dtype* top_data = (*top)[0]->mutable_gpu_data();
-  if (concat_dim_==0) {
-    int offset_num = 0;
-    for (int i=0; i<bottom.size(); ++i) {
-      const Dtype* bottom_data = bottom[i]->gpu_data();
-      caffe_gpu_copy(bottom[i]->count(), bottom_data, top_data+(*top)[0]->offset(offset_num));
-      offset_num += bottom[i]->num();
-    }
-  } else if (concat_dim_ == 1) {
-    int offset_channel = 0;
-    for (int i=0; i<bottom.size(); ++i) {
-      const Dtype* bottom_data = bottom[i]->gpu_data();
-      int num_elem = bottom[i]->channels()*bottom[i]->height()*bottom[i]->width();
-      for (int n=0; n<NUM_; ++n){
-        caffe_gpu_copy(num_elem, bottom_data+bottom[i]->offset(n),
-          top_data+(*top)[0]->offset(n,offset_channel));
-      }
-      offset_channel += bottom[i]->channels();
-    }
-  } else {
-    LOG(FATAL) << "concat_dim along dim" << concat_dim_ << " not implemented yet";
-  }
+    LOG(FATAL) << "concat_dim along dim" << concat_dim_ <<
+      " not implemented yet";
+  }
 }

 template <typename Dtype>
 Dtype ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  const Dtype* top_diff = top[0]->cpu_diff();
-  if (concat_dim_==0) {
-    int offset_num = 0;
-    for (int i=0; i < bottom->size(); ++i) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  if (concat_dim_ == 0) {
+    int offset_num = 0;
+    for (int i = 0; i < bottom->size(); ++i) {
       Blob<Dtype>* blob = (*bottom)[i];
       Dtype* bottom_diff = blob->mutable_cpu_diff();
-      caffe_copy(blob->count(), top_diff+top[0]->offset(offset_num),bottom_diff);
+      caffe_copy(blob->count(),
+        top_diff+top[0]->offset(offset_num), bottom_diff);
       offset_num += blob->num();
     }
   } else if (concat_dim_ == 1) {
-    int offset_channel = 0;
-    for (int i=0; i < bottom->size(); ++i) {
+    int offset_channel = 0;
+    for (int i = 0; i < bottom->size(); ++i) {
       Blob<Dtype>* blob = (*bottom)[i];
       Dtype* bottom_diff = blob->mutable_cpu_diff();
       int num_elem = blob->channels()*blob->height()*blob->width();
-      for (int n=0; n<NUM_; ++n){
-        caffe_copy(num_elem, top_diff+top[0]->offset(n,offset_channel),
-          bottom_diff+blob->offset(n));
-      }
+      for (int n = 0; n < NUM_; ++n) {
+        caffe_copy(num_elem, top_diff+top[0]->offset(n, offset_channel),
+          bottom_diff+blob->offset(n));
+      }
       offset_channel += blob->channels();
     }
   } else {
-    LOG(FATAL) << "concat_dim along dim" << concat_dim_ << " not implemented yet";
-  }
-  return Dtype(0.);
-}
-
-
-template <typename Dtype>
-Dtype ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  const Dtype* top_diff = top[0]->gpu_diff();
-  if (concat_dim_==0) {
-    int offset_num = 0;
-    for (int i=0; i < bottom->size(); ++i) {
-      Blob<Dtype>* blob = (*bottom)[i];
-      Dtype* bottom_diff = blob->mutable_gpu_diff();
-      caffe_gpu_copy(blob->count(), top_diff+top[0]->offset(offset_num),bottom_diff);
-      offset_num += blob->num();
-    }
-  } else if (concat_dim_ == 1) {
-    int offset_channel = 0;
-    for (int i=0; i < bottom->size(); ++i) {
-      Blob<Dtype>* blob = (*bottom)[i];
-      Dtype* bottom_diff = blob->mutable_gpu_diff();
-      int num_elem = blob->channels()*blob->height()*blob->width();
-      for (int n=0; n<NUM_; ++n){
-        caffe_gpu_copy(num_elem, top_diff+top[0]->offset(n,offset_channel),
-          bottom_diff+blob->offset(n));
-      }
-      offset_channel += blob->channels();
-    }
-  } else {
-    LOG(FATAL) << "concat_dim along dim" << concat_dim_ << " not implemented yet";
-  }
+    LOG(FATAL) << "concat_dim along dim" << concat_dim_ <<
+      " not implemented yet";
+  }
   return Dtype(0.);
 }

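Forward_cpu above handles the two supported cases: along num (concat_dim 0) each input is copied as one contiguous block at a running num offset, and along channels (concat_dim 1) each input contributes one block per image at a running channel offset. Below is a minimal standalone sketch of that channel-offset arithmetic using plain std::vector buffers; Blob4, ConcatChannels and the shapes are illustrative names for this sketch, not part of the commit.

// Standalone sketch of channel-wise concatenation of N x C x H x W buffers,
// mirroring the offset(n, offset_channel) arithmetic in Forward_cpu above.
// Blob4, ConcatChannels and the shapes are illustrative, not Caffe code.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Dense 4-D buffer in NCHW order, standing in for a Caffe blob's cpu_data().
struct Blob4 {
  int num, channels, height, width;
  std::vector<float> data;
  Blob4(int n, int c, int h, int w)
      : num(n), channels(c), height(h), width(w),
        data(static_cast<std::size_t>(n) * c * h * w) {}
  // Same layout as Blob::offset(n, c): ((n * channels + c) * height) * width.
  std::size_t offset(int n, int c = 0) const {
    return ((static_cast<std::size_t>(n) * channels + c) * height) * width;
  }
};

// Concatenate along channels (concat_dim == 1); num, height and width of all
// inputs are assumed to match, as the layer's SetUp implicitly requires.
Blob4 ConcatChannels(const std::vector<Blob4>& bottoms) {
  int total_channels = 0;
  for (std::size_t i = 0; i < bottoms.size(); ++i)
    total_channels += bottoms[i].channels;
  Blob4 top(bottoms[0].num, total_channels,
            bottoms[0].height, bottoms[0].width);
  int offset_channel = 0;
  for (std::size_t i = 0; i < bottoms.size(); ++i) {
    const Blob4& b = bottoms[i];
    std::size_t num_elem =
        static_cast<std::size_t>(b.channels) * b.height * b.width;
    for (int n = 0; n < b.num; ++n) {
      // One image's worth of this input goes into the running channel slot.
      std::copy(b.data.begin() + b.offset(n),
                b.data.begin() + b.offset(n) + num_elem,
                top.data.begin() + top.offset(n, offset_channel));
    }
    offset_channel += b.channels;
  }
  return top;
}

int main() {
  std::vector<Blob4> bottoms;
  bottoms.push_back(Blob4(2, 3, 6, 5));  // same shapes as the unit test below
  bottoms.push_back(Blob4(2, 5, 6, 5));
  Blob4 top = ConcatChannels(bottoms);
  assert(top.channels == 8);
  assert(top.data.size() == bottoms[0].data.size() + bottoms[1].data.size());
  return 0;
}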
75 changes: 75 additions & 0 deletions src/caffe/layers/concat_layer.cu
@@ -0,0 +1,75 @@
+// Copyright 2014 Sergio Guadarrama
+
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  if (concat_dim_ == 0) {
+    int offset_num = 0;
+    for (int i = 0; i < bottom.size(); ++i) {
+      const Dtype* bottom_data = bottom[i]->gpu_data();
+      caffe_gpu_copy(bottom[i]->count(), bottom_data,
+        top_data+(*top)[0]->offset(offset_num));
+      offset_num += bottom[i]->num();
+    }
+  } else if (concat_dim_ == 1) {
+    int offset_channel = 0;
+    for (int i = 0; i < bottom.size(); ++i) {
+      const Dtype* bottom_data = bottom[i]->gpu_data();
+      int num_elem =
+        bottom[i]->channels()*bottom[i]->height()*bottom[i]->width();
+      for (int n = 0; n < NUM_; ++n) {
+        caffe_gpu_copy(num_elem, bottom_data+bottom[i]->offset(n),
+          top_data+(*top)[0]->offset(n, offset_channel));
+      }
+      offset_channel += bottom[i]->channels();
+    }
+  } else {
+    LOG(FATAL) << "concat_dim along dim" << concat_dim_ <<
+      " not implemented yet";
+  }
+}
+
+template <typename Dtype>
+Dtype ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  if (concat_dim_ == 0) {
+    int offset_num = 0;
+    for (int i = 0; i < bottom->size(); ++i) {
+      Blob<Dtype>* blob = (*bottom)[i];
+      Dtype* bottom_diff = blob->mutable_gpu_diff();
+      caffe_gpu_copy(blob->count(),
+        top_diff+top[0]->offset(offset_num), bottom_diff);
+      offset_num += blob->num();
+    }
+  } else if (concat_dim_ == 1) {
+    int offset_channel = 0;
+    for (int i = 0; i < bottom->size(); ++i) {
+      Blob<Dtype>* blob = (*bottom)[i];
+      Dtype* bottom_diff = blob->mutable_gpu_diff();
+      int num_elem = blob->channels()*blob->height()*blob->width();
+      for (int n = 0; n < NUM_; ++n) {
+        caffe_gpu_copy(num_elem, top_diff+top[0]->offset(n, offset_channel),
+          bottom_diff+blob->offset(n));
+      }
+      offset_channel += blob->channels();
+    }
+  } else {
+    LOG(FATAL) << "concat_dim along dim" << concat_dim_ <<
+      " not implemented yet";
+  }
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(ConcatLayer);
+
+} // namespace caffe
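The GPU code moved into this new file is line-for-line the same logic as the CPU path, with caffe_gpu_copy doing the device-side copies, and Backward_gpu performs the inverse operation: the single top gradient is sliced back into each bottom blob's diff using the same offsets. Below is a small CPU-only sketch of that split along the num axis under the same NCHW layout assumption; the names and shapes are illustrative, not the layer's API.

// Sketch of the gradient split for concat_dim == 0: each bottom blob takes
// back the contiguous slice of the top gradient it contributed. Plain vectors
// stand in for blob diffs; names and shapes are illustrative only.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const int kImageSize = 3 * 6 * 5;        // C * H * W, shared by all inputs
  const int bottom_nums[] = {2, 5};        // two inputs with 2 and 5 images
  std::vector<float> top_diff(7 * kImageSize, 1.0f);  // gradient from above

  std::vector<std::vector<float> > bottom_diffs;
  int offset_num = 0;
  for (int i = 0; i < 2; ++i) {
    int count = bottom_nums[i] * kImageSize;
    // offset(offset_num) == offset_num * C * H * W when splitting along num.
    bottom_diffs.push_back(std::vector<float>(
        top_diff.begin() + offset_num * kImageSize,
        top_diff.begin() + offset_num * kImageSize + count));
    offset_num += bottom_nums[i];
  }
  assert(bottom_diffs[0].size() == static_cast<std::size_t>(2 * kImageSize));
  assert(bottom_diffs[1].size() == static_cast<std::size_t>(5 * kImageSize));
  return 0;
}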
31 changes: 16 additions & 15 deletions src/caffe/test/test_concat_layer.cpp
@@ -1,8 +1,9 @@
 // Copyright 2014 Sergio Guadarrama

 #include <cstring>
-#include <cuda_runtime.h>
+#include <vector>

+#include "cuda_runtime.h"
 #include "gtest/gtest.h"
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
@@ -23,7 +24,7 @@ class ConcatLayerTest : public ::testing::Test {
       : blob_bottom_0(new Blob<Dtype>(2, 3, 6, 5)),
         blob_bottom_1(new Blob<Dtype>(2, 5, 6, 5)),
         blob_bottom_2(new Blob<Dtype>(5, 3, 6, 5)),
-        blob_top_(new Blob<Dtype>()) {};
+        blob_top_(new Blob<Dtype>()) {}
   virtual void SetUp() {
     // fill the values
     FillerParameter filler_param;
@@ -39,17 +40,18 @@ class ConcatLayerTest : public ::testing::Test {
     blob_bottom_vec_1.push_back(blob_bottom_0);
     blob_bottom_vec_1.push_back(blob_bottom_2);
     blob_top_vec_.push_back(blob_top_);
-  };
+  }

   virtual ~ConcatLayerTest() {
-    delete blob_bottom_0; delete blob_bottom_1; delete blob_bottom_2; delete blob_top_;
+    delete blob_bottom_0; delete blob_bottom_1;
+    delete blob_bottom_2; delete blob_top_;
   }

   Blob<Dtype>* const blob_bottom_0;
   Blob<Dtype>* const blob_bottom_1;
   Blob<Dtype>* const blob_bottom_2;
   Blob<Dtype>* const blob_top_;
-  vector<Blob<Dtype>*> blob_bottom_vec_0,blob_bottom_vec_1;
+  vector<Blob<Dtype>*> blob_bottom_vec_0, blob_bottom_vec_1;
   vector<Blob<Dtype>*> blob_top_vec_;
 };

@@ -61,7 +63,8 @@ TYPED_TEST(ConcatLayerTest, TestSetupNum) {
   layer_param.set_concat_dim(0);
   ConcatLayer<TypeParam> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_1, &(this->blob_top_vec_));
-  EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_0->num()+this->blob_bottom_2->num());
+  EXPECT_EQ(this->blob_top_->num(),
+      this->blob_bottom_0->num() + this->blob_bottom_2->num());
   EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_0->channels());
   EXPECT_EQ(this->blob_top_->height(), this->blob_bottom_0->height());
   EXPECT_EQ(this->blob_top_->width(), this->blob_bottom_0->width());
@@ -72,7 +75,8 @@ TYPED_TEST(ConcatLayerTest, TestSetupChannels) {
   ConcatLayer<TypeParam> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_0, &(this->blob_top_vec_));
   EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_0->num());
-  EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_0->channels()+this->blob_bottom_1->channels());
+  EXPECT_EQ(this->blob_top_->channels(),
+      this->blob_bottom_0->channels()+this->blob_bottom_1->channels());
   EXPECT_EQ(this->blob_top_->height(), this->blob_bottom_0->height());
   EXPECT_EQ(this->blob_top_->width(), this->blob_bottom_0->width());
 }
@@ -88,14 +92,16 @@ TYPED_TEST(ConcatLayerTest, TestCPUNum) {
     for (int c = 0; c < this->blob_bottom_0->channels(); ++c) {
       for (int h = 0; h < this->blob_top_->height(); ++h) {
         for (int w = 0; w < this->blob_top_->width(); ++w) {
-          EXPECT_EQ(this->blob_top_->data_at(n, c, h, w), this->blob_bottom_vec_0[0]->data_at(n, c, h, w));
+          EXPECT_EQ(this->blob_top_->data_at(n, c, h, w),
+              this->blob_bottom_vec_0[0]->data_at(n, c, h, w));
         }
       }
     }
     for (int c = 0; c < this->blob_bottom_1->channels(); ++c) {
       for (int h = 0; h < this->blob_top_->height(); ++h) {
         for (int w = 0; w < this->blob_top_->width(); ++w) {
-          EXPECT_EQ(this->blob_top_->data_at(n, c+3, h, w), this->blob_bottom_vec_0[1]->data_at(n, c, h, w));
+          EXPECT_EQ(this->blob_top_->data_at(n, c+3, h, w),
+              this->blob_bottom_vec_0[1]->data_at(n, c, h, w));
         }
       }
     }
@@ -108,8 +114,6 @@ TYPED_TEST(ConcatLayerTest, TestCPUGradient) {
   Caffe::set_mode(Caffe::CPU);
   ConcatLayer<TypeParam> layer(layer_param);
   GradientChecker<TypeParam> checker(1e-2, 1e-3);
-  // it is too expensive to call curand multiple times, so we don't do an
-  // exhaustive gradient check.
   checker.CheckGradient(&layer, &(this->blob_bottom_vec_0),
       &(this->blob_top_vec_));
 }
@@ -119,11 +123,8 @@ TYPED_TEST(ConcatLayerTest, TestGPUGradient) {
   Caffe::set_mode(Caffe::GPU);
   ConcatLayer<TypeParam> layer(layer_param);
   GradientChecker<TypeParam> checker(1e-2, 1e-3);
-  // it is too expensive to call curand multiple times, so we don't do an
-  // exhaustive gradient check.
   checker.CheckGradient(&layer, &(this->blob_bottom_vec_0),
       &(this->blob_top_vec_));
 }

-
-}
+} // namespace caffe
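The two setup tests only check shapes: concatenating along num sums the num dimensions, concatenating along channels sums the channel dimensions, and every other dimension must agree. With the blobs used in this fixture the arithmetic works out as in the worked example below, which is not part of the test file.

// Worked shape arithmetic for the fixtures above (not part of the test file):
// blob_bottom_0 is 2 x 3 x 6 x 5, blob_bottom_1 is 2 x 5 x 6 x 5,
// blob_bottom_2 is 5 x 3 x 6 x 5 (num x channels x height x width).
#include <cassert>

int main() {
  // TestSetupNum (concat_dim == 0): bottom_0 and bottom_2 stack along num.
  assert(2 + 5 == 7);                      // top num; 3 x 6 x 5 unchanged
  assert(2*3*6*5 + 5*3*6*5 == 7*3*6*5);    // COUNT_ check in SetUp

  // TestSetupChannels (concat_dim == 1): bottom_0 and bottom_1 stack along
  // channels.
  assert(3 + 5 == 8);                      // top channels; 2, 6, 5 unchanged
  assert(2*3*6*5 + 2*5*6*5 == 2*8*6*5);    // COUNT_ check in SetUp
  return 0;
}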
