Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-square cropping #1980

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 57 additions & 31 deletions src/caffe/data_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
Phase phase)
: param_(param), phase_(phase) {
CHECK(param_.crop_size() == 0 ||
(param_.crop_h() == 0 && param_.crop_w() == 0))
<< "Crop size is crop_size OR crop_h and crop_w; not both";
CHECK((param_.crop_h() != 0) == (param_.crop_w() != 0))
<< "For non-square crops both crop_h and crop_w are required.";
// check if we want to use mean_file
if (param_.has_mean_file()) {
CHECK_EQ(param_.mean_value_size(), 0) <<
Expand Down Expand Up @@ -43,15 +48,20 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
const int datum_width = datum.width();

const int crop_size = param_.crop_size();
int crop_h = param_.crop_h();
int crop_w = param_.crop_w();
const Dtype scale = param_.scale();
const bool do_mirror = param_.mirror() && Rand(2);
const bool has_mean_file = param_.has_mean_file();
const bool has_uint8 = data.size() > 0;
const bool has_mean_values = mean_values_.size() > 0;
if (crop_size > 0) {
crop_h = crop_w = crop_size;
}

CHECK_GT(datum_channels, 0);
CHECK_GE(datum_height, crop_size);
CHECK_GE(datum_width, crop_size);
CHECK_GE(datum_height, crop_h);
CHECK_GE(datum_width, crop_w);

Dtype* mean = NULL;
if (has_mean_file) {
Expand All @@ -76,16 +86,16 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,

int h_off = 0;
int w_off = 0;
if (crop_size) {
height = crop_size;
width = crop_size;
if (crop_h > 0 || crop_w > 0) {
height = crop_h;
width = crop_w;
// We only do random crop when we do training.
if (phase_ == TRAIN) {
h_off = Rand(datum_height - crop_size + 1);
w_off = Rand(datum_width - crop_size + 1);
h_off = Rand(datum_height - crop_h + 1);
w_off = Rand(datum_width - crop_w + 1);
} else {
h_off = (datum_height - crop_size) / 2;
w_off = (datum_width - crop_size) / 2;
h_off = (datum_height - crop_h) / 2;
w_off = (datum_width - crop_w) / 2;
}
}

Expand Down Expand Up @@ -140,10 +150,14 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
CHECK_GE(num, 1);

const int crop_size = param_.crop_size();

if (crop_size) {
CHECK_EQ(crop_size, height);
CHECK_EQ(crop_size, width);
int crop_h = param_.crop_h();
int crop_w = param_.crop_w();
if (crop_size > 0) {
crop_h = crop_w = crop_size;
}
if (crop_h > 0 || crop_w > 0) {
CHECK_EQ(crop_h, height);
CHECK_EQ(crop_w, width);
} else {
CHECK_EQ(datum_height, height);
CHECK_EQ(datum_width, width);
Expand Down Expand Up @@ -213,14 +227,19 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

const int crop_size = param_.crop_size();
int crop_h = param_.crop_h();
int crop_w = param_.crop_w();
const Dtype scale = param_.scale();
const bool do_mirror = param_.mirror() && Rand(2);
const bool has_mean_file = param_.has_mean_file();
const bool has_mean_values = mean_values_.size() > 0;
if (crop_size > 0) {
crop_h = crop_w = crop_size;
}

CHECK_GT(img_channels, 0);
CHECK_GE(img_height, crop_size);
CHECK_GE(img_width, crop_size);
CHECK_GE(img_height, crop_h);
CHECK_GE(img_width, crop_w);

Dtype* mean = NULL;
if (has_mean_file) {
Expand All @@ -243,18 +262,18 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
int h_off = 0;
int w_off = 0;
cv::Mat cv_cropped_img = cv_img;
if (crop_size) {
CHECK_EQ(crop_size, height);
CHECK_EQ(crop_size, width);
if (crop_h > 0 || crop_w > 0) {
CHECK_EQ(crop_h, height);
CHECK_EQ(crop_w, width);
// We only do random crop when we do training.
if (phase_ == TRAIN) {
h_off = Rand(img_height - crop_size + 1);
w_off = Rand(img_width - crop_size + 1);
h_off = Rand(img_height - crop_h + 1);
w_off = Rand(img_width - crop_w + 1);
} else {
h_off = (img_height - crop_size) / 2;
w_off = (img_width - crop_size) / 2;
h_off = (img_height - crop_h) / 2;
w_off = (img_width - crop_w) / 2;
}
cv::Rect roi(w_off, h_off, crop_size, crop_size);
cv::Rect roi(w_off, h_off, crop_h, crop_w);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the crop_h and crop_w are swapped.

cv_cropped_img = cv_img(roi);
} else {
CHECK_EQ(img_height, height);
Expand Down Expand Up @@ -314,23 +333,28 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
CHECK_GE(input_width, width);

const int crop_size = param_.crop_size();
int crop_h= param_.crop_h();
int crop_w= param_.crop_w();
const Dtype scale = param_.scale();
const bool do_mirror = param_.mirror() && Rand(2);
const bool has_mean_file = param_.has_mean_file();
const bool has_mean_values = mean_values_.size() > 0;
if (crop_size > 0) {
crop_h = crop_w = crop_size;
}

int h_off = 0;
int w_off = 0;
if (crop_size) {
CHECK_EQ(crop_size, height);
CHECK_EQ(crop_size, width);
if (crop_h > 0 || crop_w > 0) {
CHECK_EQ(crop_h, height);
CHECK_EQ(crop_w, width);
// We only do random crop when we do training.
if (phase_ == TRAIN) {
h_off = Rand(input_height - crop_size + 1);
w_off = Rand(input_width - crop_size + 1);
h_off = Rand(input_height - crop_h + 1);
w_off = Rand(input_width - crop_w + 1);
} else {
h_off = (input_height - crop_size) / 2;
w_off = (input_width - crop_size) / 2;
h_off = (input_height - crop_h) / 2;
w_off = (input_width - crop_w) / 2;
}
} else {
CHECK_EQ(input_height, height);
Expand Down Expand Up @@ -397,8 +421,10 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,

template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
const bool needs_crop = param_.crop_size() > 0 ||
param_.crop_h() > 0 || param_.crop_w() > 0;
const bool needs_rand = param_.mirror() ||
(phase_ == TRAIN && param_.crop_size());
(phase_ == TRAIN && needs_crop);
if (needs_rand) {
const unsigned int rng_seed = caffe_rng_rand();
rng_.reset(new Caffe::RNG(rng_seed));
Expand Down
17 changes: 12 additions & 5 deletions src/caffe/layers/data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,18 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
LOG(INFO) << "Decoding Datum";
}
// image
int crop_size = this->layer_param_.transform_param().crop_size();
const int crop_size = this->layer_param_.transform_param().crop_size();
int crop_h = this->layer_param_.transform_param().crop_h();
int crop_w = this->layer_param_.transform_param().crop_w();
if (crop_size > 0) {
crop_h = crop_w = crop_size;
}
if (crop_h > 0 || crop_w > 0) {
top[0]->Reshape(this->layer_param_.data_param().batch_size(),
datum.channels(), crop_size, crop_size);
datum.channels(), crop_h, crop_w);
this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
datum.channels(), crop_size, crop_size);
this->transformed_data_.Reshape(1, datum.channels(), crop_size, crop_size);
datum.channels(), crop_h, crop_w);
this->transformed_data_.Reshape(1, datum.channels(), crop_h, crop_w);
} else {
top[0]->Reshape(
this->layer_param_.data_param().batch_size(), datum.channels(),
Expand Down Expand Up @@ -89,8 +94,10 @@ void DataLayer<Dtype>::InternalThreadEntry() {
// Reshape on single input batches for inputs of varying dimension.
const int batch_size = this->layer_param_.data_param().batch_size();
const int crop_size = this->layer_param_.transform_param().crop_size();
const int crop_h = this->layer_param_.transform_param().crop_h();
const int crop_w = this->layer_param_.transform_param().crop_w();
bool force_color = this->layer_param_.data_param().force_encoded_color();
if (batch_size == 1 && crop_size == 0) {
if (batch_size == 1 && crop_size == 0 && crop_h == 0 && crop_w == 0) {
Datum datum;
datum.ParseFromString(cursor_->value());
if (datum.encoded()) {
Expand Down
16 changes: 12 additions & 4 deletions src/caffe/layers/image_data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,16 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const int width = cv_img.cols;
// image
const int crop_size = this->layer_param_.transform_param().crop_size();
int crop_h = this->layer_param_.transform_param().crop_h();
int crop_w = this->layer_param_.transform_param().crop_w();
const int batch_size = this->layer_param_.image_data_param().batch_size();
if (crop_size > 0) {
top[0]->Reshape(batch_size, channels, crop_size, crop_size);
this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size);
this->transformed_data_.Reshape(1, channels, crop_size, crop_size);
crop_h = crop_w = crop_size;
}
if (crop_h > 0 || crop_w > 0) {
top[0]->Reshape(batch_size, channels, crop_h, crop_w);
this->prefetch_data_.Reshape(batch_size, channels, crop_h, crop_w);
this->transformed_data_.Reshape(1, channels, crop_h, crop_w);
} else {
top[0]->Reshape(batch_size, channels, height, width);
this->prefetch_data_.Reshape(batch_size, channels, height, width);
Expand Down Expand Up @@ -107,11 +112,14 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
const int new_height = image_data_param.new_height();
const int new_width = image_data_param.new_width();
const int crop_size = this->layer_param_.transform_param().crop_size();
const int crop_h = this->layer_param_.transform_param().crop_h();
const int crop_w = this->layer_param_.transform_param().crop_w();
const bool needs_crop = crop_size != 0 || crop_h != 0 || crop_w != 0;
const bool is_color = image_data_param.is_color();
string root_folder = image_data_param.root_folder();

// Reshape on single input batches for inputs of varying dimension.
if (batch_size == 1 && crop_size == 0 && new_height == 0 && new_width == 0) {
if (batch_size == 1 && !needs_crop && new_height == 0 && new_width == 0) {
cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
0, 0, is_color);
this->prefetch_data_.Reshape(1, cv_img.channels(),
Expand Down
2 changes: 2 additions & 0 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ message TransformationParameter {
optional bool mirror = 2 [default = false];
// Specify if we would like to randomly crop an image.
optional uint32 crop_size = 3 [default = 0];
optional uint32 crop_h = 6 [default = 0]; // The crop height
optional uint32 crop_w = 7 [default = 0]; // The crop width
// mean_file and mean_value cannot be specified at the same time
optional string mean_file = 4;
// if specified can be repeated once (would substract it from all the channels)
Expand Down
69 changes: 57 additions & 12 deletions src/caffe/test/test_data_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,17 @@ class DataTransformTest : public ::testing::Test {
DataTransformer<Dtype>* transformer =
new DataTransformer<Dtype>(transform_param, phase);
const int crop_size = transform_param.crop_size();
int crop_h = transform_param.crop_h();
int crop_w = transform_param.crop_w();
if (crop_size > 0) {
crop_h = crop_w = crop_size;
}
Caffe::set_random_seed(seed_);
transformer->InitRand();
Blob<Dtype>* blob =
new Blob<Dtype>(1, datum.channels(), datum.height(), datum.width());
if (transform_param.crop_size() > 0) {
blob->Reshape(1, datum.channels(), crop_size, crop_size);
if (crop_h > 0 || crop_w > 0) {
blob->Reshape(1, datum.channels(), crop_h, crop_w);
}

vector<vector<Dtype> > crop_sequence;
Expand Down Expand Up @@ -157,17 +162,50 @@ TYPED_TEST(DataTransformTest, TestCropSize) {
}
}

TYPED_TEST(DataTransformTest, TestCrop) {
TransformationParameter transform_param;
const bool unique_pixels = false; // all pixels the same equal to label
const int label = 0;
const int channels = 3;
const int height = 4;
const int width = 5;
const int crop_h = 3;
const int crop_w = 2;

transform_param.set_crop_h(crop_h);
transform_param.set_crop_w(crop_w);
Datum datum;
FillDatum(label, channels, height, width, unique_pixels, &datum);
DataTransformer<TypeParam>* transformer =
new DataTransformer<TypeParam>(transform_param, TEST);
transformer->InitRand();
Blob<TypeParam>* blob =
new Blob<TypeParam>(1, channels, crop_h, crop_w);
for (int iter = 0; iter < this->num_iter_; ++iter) {
transformer->Transform(datum, blob);
EXPECT_EQ(blob->num(), 1);
EXPECT_EQ(blob->channels(), datum.channels());
EXPECT_EQ(blob->height(), crop_h);
EXPECT_EQ(blob->width(), crop_w);
for (int j = 0; j < blob->count(); ++j) {
EXPECT_EQ(blob->cpu_data()[j], label);
}
}
}

TYPED_TEST(DataTransformTest, TestCropTrain) {
TransformationParameter transform_param;
const bool unique_pixels = true; // pixels are consecutive ints [0,size]
const int label = 0;
const int channels = 3;
const int height = 4;
const int width = 5;
const int crop_size = 2;
const int size = channels * crop_size * crop_size;
const int crop_h = 3;
const int crop_w = 2;
const int size = channels * crop_h * crop_w;

transform_param.set_crop_size(crop_size);
transform_param.set_crop_h(crop_h);
transform_param.set_crop_w(crop_w);
Datum datum;
FillDatum(label, channels, height, width, unique_pixels, &datum);
int num_matches = this->NumSequenceMatches(transform_param, datum, TRAIN);
Expand All @@ -181,16 +219,19 @@ TYPED_TEST(DataTransformTest, TestCropTest) {
const int channels = 3;
const int height = 4;
const int width = 5;
const int crop_size = 2;
const int size = channels * crop_size * crop_size;
const int crop_h = 3;
const int crop_w = 2;
const int size = channels * crop_h * crop_w;

transform_param.set_crop_size(crop_size);
transform_param.set_crop_h(crop_h);
transform_param.set_crop_w(crop_w);
Datum datum;
FillDatum(label, channels, height, width, unique_pixels, &datum);
int num_matches = this->NumSequenceMatches(transform_param, datum, TEST);
EXPECT_EQ(num_matches, size * this->num_iter_);
}


TYPED_TEST(DataTransformTest, TestMirrorTrain) {
TransformationParameter transform_param;
const bool unique_pixels = true; // pixels are consecutive ints [0,size]
Expand Down Expand Up @@ -230,11 +271,13 @@ TYPED_TEST(DataTransformTest, TestCropMirrorTrain) {
const int channels = 3;
const int height = 4;
const int width = 5;
const int crop_size = 2;
const int crop_h = 3;
const int crop_w = 2;

Datum datum;
FillDatum(label, channels, height, width, unique_pixels, &datum);
transform_param.set_crop_size(crop_size);
transform_param.set_crop_h(crop_h);
transform_param.set_crop_w(crop_w);
int num_matches_crop = this->NumSequenceMatches(
transform_param, datum, TRAIN);

Expand All @@ -252,11 +295,13 @@ TYPED_TEST(DataTransformTest, TestCropMirrorTest) {
const int channels = 3;
const int height = 4;
const int width = 5;
const int crop_size = 2;
const int crop_h = 3;
const int crop_w = 2;

Datum datum;
FillDatum(label, channels, height, width, unique_pixels, &datum);
transform_param.set_crop_size(crop_size);
transform_param.set_crop_h(crop_h);
transform_param.set_crop_w(crop_w);
int num_matches_crop = this->NumSequenceMatches(transform_param, datum, TEST);

transform_param.set_mirror(true);
Expand Down