Skip to content

Commit

Permalink
Updating crop to work for ND Blobs.
Browse files Browse the repository at this point in the history
  • Loading branch information
BlGene committed Jan 19, 2016
1 parent 0af80f0 commit c8ecb5e
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 20 deletions.
8 changes: 7 additions & 1 deletion include/caffe/layers/crop_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ class CropLayer : public Layer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

int crop_h_, crop_w_;
vector<int> offsets;
private:
void crop_copy(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top,
vector<int>& offsets,
vector<int>& indices,
int cur_dim);
};

} // namespace caffe
Expand Down
101 changes: 85 additions & 16 deletions src/caffe/layers/crop_layer.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <vector>


#include "caffe/layer.hpp"
#include "caffe/layers/crop_layer.hpp"
#include "caffe/net.hpp"
Expand All @@ -17,43 +19,110 @@ void CropLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_EQ(bottom.size(), 2) << "Wrong number of bottom blobs.";
CHECK_EQ(bottom[0]->num_axes(), 4) << "Only works with 4D blobs.";
CHECK_EQ(bottom[1]->num_axes(), 4) << "Only works with 4D blobs.";
crop_h_ = param.offset_height();
crop_w_ = param.offset_width();
// Move to Reshape
if(param.has_offsets() ) {
//CHECK_EQ(param.offsets().dim_size(),top[0]->num_axes()) << "shape missmatch";
for(int i=0;i<param.offsets().dim_size();i++) {
offsets.push_back(param.offsets().dim(i));
}
} else {
for(int i=0;i<top[0]->num_axes();i++) {
offsets.push_back(0);
}
}
}

template <typename Dtype>
void CropLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// Check that the image we are cropping minus the margin is bigger than the
// destination image.
CHECK_GT(bottom[0]->height()-crop_h_, bottom[1]->height())
<< "invalid offset";
CHECK_GT(bottom[0]->width()-crop_w_, bottom[1]->width()) << "invalid offset";
//CHECK_GT(bottom[0]->height()-crop_h_, bottom[1]->height())
// << "invalid offset";
//CHECK_GT(bottom[0]->width()-crop_w_, bottom[1]->width()) << "invalid offset";
top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), bottom[1]->height(),
bottom[1]->width());
}



// Vector simple addition, using first argument as size
std::vector<int> vec_add(const std::vector<int> a, const std::vector<int> b)
{
assert(a.size() <= b.size());
std::vector<int> result;
result.reserve(a.size());

//std::transform(a.begin(), a.end(), b.begin(),
// std::back_inserter(result), std::plus<int>());
for(int i=0;i<a.size();i++) {
result.push_back(a[i] + b[i]);
}
return result;
}

// recursive copy function
template <typename Dtype>
void CropLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
for (int n = 0; n < top[0]->num(); ++n) {
for (int c = 0; c < top[0]->channels(); ++c) {
for (int h = 0; h < top[0]->height(); ++h) {
caffe_copy(top[0]->width(),
bottom_data + bottom[0]->offset(n, c, crop_h_ + h, crop_w_),
top_data + top[0]->offset(n, c, h));
}
void CropLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top,
vector<int>& offsets,
vector<int>& indices,
int cur_dim) {
if( cur_dim + 1 < top[0]->num_axes() ) {
// We are not yet at the final dimension, call copy recursivley
for(int i=0;i<top[0]->shape(cur_dim);++i) {
indices[cur_dim] = i;
crop_copy(bottom,top,offsets,indices,cur_dim+1);
}
} else {
// We are at the last dimensions, which is stored continously in memory
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
for(int i=0;i<top[0]->shape(cur_dim);++i) {
indices[cur_dim] = i;

vector<int> ired = vector<int>(indices);
ired.resize(cur_dim); // which is top[0]->num_axes() - 1

vector<int> ipluso = vec_add(ired,offsets);
ipluso.push_back(offsets[cur_dim]);

LOG(INFO) << "copy: " << cur_dim << " "
<< top[0]->shape(cur_dim) << " ";
LOG(INFO) << ipluso[0] << " " << ipluso[1] << " " << ipluso[2] << " " << ipluso[3];
LOG(INFO) << ired[0] << " " << ired[1] << " " << ired[2];
LOG(INFO);
caffe_copy( top[0]->shape(cur_dim),
bottom_data + bottom[0]->offset( ipluso ),
top_data + top[0]->offset( ired ) );
}
}
}

template <typename Dtype>
void CropLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// Replace this with conversion from params
//std::vector<int> offsets(4,0);
//offsets[0] = 0;
//offsets[1] = 0;
//offsets[2] = crop_h_;
//offsets[3] = crop_w_;

//LOG(INFO) << offsets[0] << " " << offsets[1] << " " << offsets[2] << " " << offsets[3];
//LOG(FATAL);
std::vector<int> indices(top[0]->num_axes(), 0);
crop_copy(bottom,top,offsets,indices,0);

}

template <typename Dtype>
void CropLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
int crop_h_ = offsets[2];
int crop_w_ = offsets[3];
if (propagate_down[0]) {
caffe_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
for (int n = 0; n < top[0]->num(); ++n) {
Expand Down
4 changes: 4 additions & 0 deletions src/caffe/layers/crop_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ void CropLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
const int lines = top[0]->count() / top[0]->width();
int crop_h_ = offsets[2];
int crop_w_ = offsets[3];

// NOLINT_NEXT_LINE(whitespace/operators)
copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
Expand All @@ -43,6 +45,8 @@ void CropLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
const int lines = top[0]->count() / top[0]->width();
int crop_h_ = offsets[2];
int crop_w_ = offsets[3];

if (propagate_down[0]) {
caffe_gpu_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
Expand Down
4 changes: 1 addition & 3 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,7 @@ message ConvolutionParameter {

message CropParameter {
// Assumes standard dimensions: ( N,C,H,W )
// This could possibly be extended to use "optional BlobShape offsets"
optional uint32 offset_height = 1[default = 0];
optional uint32 offset_width = 2[default = 0];
optional BlobShape offsets = 1;
}

message DataParameter {
Expand Down

0 comments on commit c8ecb5e

Please sign in to comment.