Merge pull request #28 from kloudkl/boost-eigen
Replace MKL with Boost+Eigen3

* commit '70c4320e436f92d0963b2622d20c7435b2f07f30':
  Fix test_data_layer segfault by adding destructor to join pthread
  Fix math funcs, add tests, change Eigen Map to unaligned for lrn_layer
  Fix test stochastic pooling stepsize/threshold to be same as max pooling
  Fixed FlattenLayer Backward_cpu/gpu have no return value
  Fixed uniform distribution upper bound to be inclusive
  Add python scripts to install dependent development libs

shelhamer committed Jan 22, 2014
2 parents 5385b74 + 958f038 commit 8d894f0
Showing 13 changed files with 510 additions and 131 deletions.
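Before the per-file diffs, a minimal sketch of the substitution this PR makes overall (hypothetical wrapper name and shapes; the real wrappers live in src/caffe/util/math_functions.cpp, which this excerpt does not show): an MKL vector routine such as vsAdd is replaced by a Boost+Eigen equivalent, using unaligned Eigen maps as the lrn_layer commit message above notes.

#include <Eigen/Dense>

// Hypothetical sketch of an MKL-to-Eigen replacement for vsAdd:
// y = a + b elementwise over n floats. Unaligned maps are used because
// the underlying buffers carry no alignment guarantee.
void caffe_add_sketch(const int n, const float* a, const float* b, float* y) {
  typedef Eigen::Map<const Eigen::VectorXf, Eigen::Unaligned> ConstVec;
  typedef Eigen::Map<Eigen::VectorXf, Eigen::Unaligned> Vec;
  Vec(y, n) = ConstVec(a, n) + ConstVec(b, n);
}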
8 changes: 8 additions & 0 deletions include/caffe/blob.hpp
@@ -27,6 +27,14 @@ class Blob {
  inline int count() const { return count_; }
inline int offset(const int n, const int c = 0, const int h = 0,
const int w = 0) const {
    CHECK_GE(n, 0);
    CHECK_LE(n, num_);
    CHECK_GE(c, 0);
    CHECK_LE(c, channels_);
    CHECK_GE(h, 0);
    CHECK_LE(h, height_);
    CHECK_GE(w, 0);
    CHECK_LE(w, width_);
return ((n * channels_ + c) * height_ + h) * width_ + w;
}
// Copy from source. If copy_diff is false, we copy the data; if copy_diff
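For reference, the indexing scheme these checks guard is plain row-major NCHW; a small self-contained example with hypothetical dimensions:

#include <cassert>

// Standalone check of the indexing scheme used by offset().
int offset_nchw(int n, int c, int h, int w,
                int channels, int height, int width) {
  return ((n * channels + c) * height + h) * width + w;
}

int main() {
  // num=2, channels=3, height=4, width=5 => count = 120 elements.
  assert(offset_nchw(1, 2, 3, 4, 3, 4, 5) == 119);  // last element
  assert(offset_nchw(0, 0, 0, 0, 3, 4, 5) == 0);    // first element
  return 0;
}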
3 changes: 3 additions & 0 deletions include/caffe/util/math_functions.hpp
@@ -85,6 +85,9 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
template <typename Dtype>
void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);

template <typename Dtype>
Dtype caffe_nextafter(const Dtype b);

template <typename Dtype>
void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);

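The new caffe_nextafter supports the "uniform distribution upper bound to be inclusive" fix listed in the merge message: Boost's uniform_real samples the half-open interval [a, b), so feeding it the next representable value after b makes b itself attainable. A plausible sketch of the definition (the actual implementation is in math_functions.cpp, not shown in this excerpt):

#include <boost/math/special_functions/next.hpp>
#include <limits>

// Sketch: the smallest representable value strictly greater than b.
template <typename Dtype>
Dtype caffe_nextafter(const Dtype b) {
  return boost::math::nextafter<Dtype>(b, std::numeric_limits<Dtype>::max());
}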
1 change: 1 addition & 0 deletions include/caffe/vision_layers.hpp
@@ -294,6 +294,7 @@ class DataLayer : public Layer<Dtype> {
public:
explicit DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~DataLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

10 changes: 10 additions & 0 deletions src/caffe/layers/data_layer.cpp
@@ -17,8 +17,11 @@ namespace caffe {

template <typename Dtype>
void* DataLayerPrefetch(void* layer_pointer) {
CHECK(layer_pointer);
DataLayer<Dtype>* layer = reinterpret_cast<DataLayer<Dtype>*>(layer_pointer);
CHECK(layer);
Datum datum;
CHECK(layer->prefetch_data_);
Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
const Dtype scale = layer->layer_param_.scale();
@@ -38,6 +41,8 @@ void* DataLayerPrefetch(void* layer_pointer) {
const Dtype* mean = layer->data_mean_.cpu_data();
for (int itemid = 0; itemid < batchsize; ++itemid) {
// get a blob
CHECK(layer->iter_);
CHECK(layer->iter_->Valid());
datum.ParseFromString(layer->iter_->value().ToString());
const string& data = datum.data();
if (cropsize) {
@@ -109,6 +114,11 @@ void* DataLayerPrefetch(void* layer_pointer) {
return (void*)NULL;
}

template <typename Dtype>
DataLayer<Dtype>::~DataLayer<Dtype>() {
// Finally, join the thread
CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
}

template <typename Dtype>
void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
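The new destructor pairs with a pthread_create presumably issued in SetUp (not shown in this excerpt): joining before the layer's members are freed is what stops the prefetch thread from dereferencing dead buffers, the segfault named in the commit message. A minimal sketch of the pattern with hypothetical names:

#include <pthread.h>
#include <glog/logging.h>

// Hypothetical owner of a worker thread: create on start, join on destroy.
class Worker {
 public:
  Worker() : started_(false) {}
  void Start(void* (*fn)(void*), void* arg) {
    CHECK(!pthread_create(&thread_, NULL, fn, arg)) << "Thread creation failed.";
    started_ = true;
  }
  ~Worker() {
    // Without this join the thread may outlive the object and touch freed state.
    if (started_) {
      CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
    }
  }
 private:
  pthread_t thread_;
  bool started_;
};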
2 changes: 2 additions & 0 deletions src/caffe/layers/flatten_layer.cpp
@@ -43,6 +43,7 @@ Dtype FlattenLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
caffe_copy(count_, top_diff, bottom_diff);
return Dtype(0);
}


@@ -52,6 +53,7 @@ Dtype FlattenLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
caffe_gpu_copy(count_, top_diff, bottom_diff);
return Dtype(0);
}

INSTANTIATE_CLASS(FlattenLayer);
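Context for the one-line fix above: Backward_cpu and Backward_gpu are declared to return a Dtype (the layer's own loss contribution, judging by how the gradient checker later in this diff accumulates the return value), and in C++ flowing off the end of a value-returning function is undefined behavior. A tiny self-contained illustration of the defect class:

// Undefined behavior: control can flow off the end of a value-returning
// function; most compilers only warn (e.g. -Wreturn-type). FlattenLayer's
// Backward_cpu/gpu had exactly this shape before the fix.
int missing_return(bool flag) {
  if (flag) {
    return 1;
  }
  // no return here: using the result of missing_return(false) is undefined
}

int fixed_return(bool flag) {
  if (flag) {
    return 1;
  }
  return 0;  // explicit value on every path, mirroring "return Dtype(0);"
}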
4 changes: 2 additions & 2 deletions src/caffe/test/test_data_layer.cpp
@@ -81,8 +81,8 @@ TYPED_TEST(DataLayerTest, TestRead) {
EXPECT_EQ(this->blob_top_label_->channels(), 1);
EXPECT_EQ(this->blob_top_label_->height(), 1);
EXPECT_EQ(this->blob_top_label_->width(), 1);
-  // Go throught the data twice
-  for (int iter = 0; iter < 2; ++iter) {
+  // Go through the data 100 times
+  for (int iter = 0; iter < 100; ++iter) {
layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
for (int i = 0; i < 5; ++i) {
EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]);
3 changes: 3 additions & 0 deletions src/caffe/test/test_flatten_layer.cpp
@@ -22,6 +22,7 @@ class FlattenLayerTest : public ::testing::Test {
FlattenLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
blob_top_(new Blob<Dtype>()) {
Caffe::set_random_seed(1701);
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
@@ -72,6 +73,8 @@ TYPED_TEST(FlattenLayerTest, TestGPU) {
for (int c = 0; c < 3 * 6 * 5; ++c) {
EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
this->blob_bottom_->data_at(0, c / (6 * 5), (c / 5) % 6, c % 5));
EXPECT_EQ(this->blob_top_->data_at(1, c, 0, 0),
this->blob_bottom_->data_at(1, c / (6 * 5), (c / 5) % 6, c % 5));
}
}

21 changes: 13 additions & 8 deletions src/caffe/test/test_gradient_check_util.hpp
@@ -82,11 +82,11 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
blobs_to_check.push_back(bottom[check_bottom]);
}
// go through the bottom and parameter blobs
// LOG(ERROR) << "Checking " << blobs_to_check.size() << " blobs.";
// LOG(ERROR) << "Checking " << blobs_to_check.size() << " blobs.";
for (int blobid = 0; blobid < blobs_to_check.size(); ++blobid) {
Blob<Dtype>* current_blob = blobs_to_check[blobid];
// LOG(ERROR) << "Blob " << blobid << ": checking " << current_blob->count()
// << " parameters.";
// LOG(ERROR) << "Blob " << blobid << ": checking " << current_blob->count()
// << " parameters.";
// go through the values
for (int feat_id = 0; feat_id < current_blob->count(); ++feat_id) {
// First, obtain the original data
@@ -96,25 +96,28 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
// Get any additional loss from the layer
computed_objective += layer.Backward(top, true, &bottom);
Dtype computed_gradient = current_blob->cpu_diff()[feat_id];

// compute score by adding stepsize
current_blob->mutable_cpu_data()[feat_id] += stepsize_;
Caffe::set_random_seed(seed_);
layer.Forward(bottom, &top);
Dtype positive_objective = GetObjAndGradient(top, top_id, top_data_id);
positive_objective += layer.Backward(top, true, &bottom);

// compute score by subtracting stepsize
current_blob->mutable_cpu_data()[feat_id] -= stepsize_ * 2;
Caffe::set_random_seed(seed_);
layer.Forward(bottom, &top);
Dtype negative_objective = GetObjAndGradient(top, top_id, top_data_id);
negative_objective += layer.Backward(top, true, &bottom);

// Recover stepsize
current_blob->mutable_cpu_data()[feat_id] += stepsize_;
Dtype estimated_gradient = (positive_objective - negative_objective) /
stepsize_ / 2.;
Dtype feature = current_blob->cpu_data()[feat_id];
// LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " "
// << current_blob->cpu_diff()[feat_id];
// LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " "
// << current_blob->cpu_diff()[feat_id];
if (kink_ - kink_range_ > feature || feature > kink_ + kink_range_) {
// We check relative accuracy, but for too small values, we threshold
// the scale factor by 1.
@@ -126,10 +129,12 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale)
<< "debug: (top_id, top_data_id, blob_id, feat_id)="
<< top_id << "," << top_data_id << "," << blobid << "," << feat_id;
// LOG(ERROR) << "computed gradient: " << computed_gradient
// << " estimated_gradient: " << estimated_gradient
// << " positive_objective: " << positive_objective
// << " negative_objective: " << negative_objective;
}
// LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id];
// LOG(ERROR) << "computed gradient: " << computed_gradient
// << " estimated_gradient: " << estimated_gradient;
// LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id]
}
}
}
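The hunk above is the standard centered finite-difference check: perturb one value by plus or minus stepsize_, re-run Forward and Backward, and compare the analytic gradient against (f(x + h) - f(x - h)) / (2h). Re-seeding the RNG before each Forward makes stochastic layers see identical randomness in both evaluations, so the two objectives are comparable. A minimal self-contained sketch of the estimator, with hypothetical names:

#include <cassert>
#include <cmath>

// Centered finite difference: f'(x) ~= (f(x + h) - f(x - h)) / (2h).
double estimate_gradient(double (*f)(double), double x, double h) {
  return (f(x + h) - f(x - h)) / (2.0 * h);
}

double square(double x) { return x * x; }

int main() {
  // d/dx x^2 = 2x; at x = 3 the estimate should be very close to 6.
  assert(std::fabs(estimate_gradient(square, 3.0, 1e-4) - 6.0) < 1e-6);
  return 0;
}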
194 changes: 194 additions & 0 deletions src/caffe/test/test_math_functions.cpp
@@ -0,0 +1,194 @@
// Copyright 2013 Yangqing Jia

#include <cmath>
#include <cstring>

#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/util/math_functions.hpp"

#include "caffe/test/test_caffe_main.hpp"

namespace caffe {

template <typename Dtype>
class MathFunctionsTest : public ::testing::Test {
protected:
MathFunctionsTest()
: loops_(10)
,a_(new Blob<Dtype>(2, 3, 6, 5))
,b_(new Blob<Dtype>(2, 3, 6, 5))
,y_(new Blob<Dtype>(2, 3, 6, 5))
,a_cpu_data_(a_->cpu_data())
,b_cpu_data_(b_->cpu_data())
,y_cpu_data_(y_->mutable_cpu_data())
,near_delta_(1e-5)
  {}

virtual void SetUp() {
num_ = a_->count();
filler_param_.set_min(1e-5);
filler_param_.set_max(10);
  }

virtual ~MathFunctionsTest() {
delete a_;
delete b_;
delete y_;
}

int loops_;
int num_;
Blob<Dtype>* a_;
Blob<Dtype>* b_;
Blob<Dtype>* y_;
const Dtype* const a_cpu_data_;
const Dtype* const b_cpu_data_;
Dtype* y_cpu_data_;
const Dtype near_delta_;
FillerParameter filler_param_;
};

typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(MathFunctionsTest, Dtypes);

TYPED_TEST(MathFunctionsTest, TestAdd) {
GaussianFiller<TypeParam> filler(this->filler_param_);
for (int l = 0; l < this->loops_; ++l) {
filler.Fill(this->a_);
filler.Fill(this->b_);
caffe_add(this->num_, this->a_cpu_data_, this->b_cpu_data_, this->y_cpu_data_);
for (int i = 0; i < this->num_; ++i) {
EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] + this->b_cpu_data_[i], this->near_delta_);
}
}
}

TYPED_TEST(MathFunctionsTest, TestSub) {
GaussianFiller<TypeParam> filler(this->filler_param_);
for (int l = 0; l < this->loops_; ++l) {
filler.Fill(this->a_);
filler.Fill(this->b_);
caffe_sub(this->num_, this->a_cpu_data_, this->b_cpu_data_, this->y_cpu_data_);
for (int i = 0; i < this->num_; ++i) {
EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] - this->b_cpu_data_[i], this->near_delta_);
}
}
}

TYPED_TEST(MathFunctionsTest, TestMul) {
GaussianFiller<TypeParam> filler(this->filler_param_);
for (int l = 0; l < this->loops_; ++l) {
filler.Fill(this->a_);
filler.Fill(this->b_);
caffe_mul(this->num_, this->a_->cpu_data(), this->b_->cpu_data(), this->y_cpu_data_);
for (int i = 0; i < this->num_; ++i) {
EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] * this->b_cpu_data_[i], this->near_delta_);
}
}
}

TYPED_TEST(MathFunctionsTest, TestDiv) {
GaussianFiller<TypeParam> filler(this->filler_param_);
UniformFiller<TypeParam> uniform_filler(this->filler_param_);
for (int l = 0; l < this->loops_; ++l) {
filler.Fill(this->a_);
    // Fill b with the uniform filler: filler_param_ has min = 1e-5 (set in
    // SetUp), which keeps b away from zero and avoids dividing by zero.
    uniform_filler.Fill(this->b_);
caffe_div(this->num_, this->a_cpu_data_, this->b_cpu_data_, this->y_cpu_data_);
for (int i = 0; i < this->num_; ++i) {
EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] /
this->b_cpu_data_[i], this->near_delta_);
}
}
}

TYPED_TEST(MathFunctionsTest, TestPowx) {
  GaussianFiller<TypeParam> filler(this->filler_param_);
  UniformFiller<TypeParam> uniform_filler(this->filler_param_);
  const TypeParam exponents[] = {0, 0.5, -0.5, 1.5, -1.5};
  const int num_exponents = 5;
  for (int l = 0; l < this->loops_; ++l) {
    for (int e = 0; e < num_exponents; ++e) {
      const TypeParam p = exponents[e];
      if (p == 0) {
        // Any input is valid for p == 0.
        filler.Fill(this->a_);
      } else {
        // Fractional and negative exponents need strictly positive inputs.
        uniform_filler.Fill(this->a_);
      }
      caffe_powx(this->num_, this->a_cpu_data_, p, this->y_cpu_data_);
      for (int i = 0; i < this->num_; ++i) {
        EXPECT_NEAR(this->y_cpu_data_[i], std::pow(this->a_cpu_data_[i], p),
                    this->near_delta_)
            << "debug: (i, y_cpu_data_, a_cpu_data_, p)="
            << i << "," << this->y_cpu_data_[i] << "," << this->a_cpu_data_[i]
            << "," << p;
      }
    }
  }
}

TYPED_TEST(MathFunctionsTest, TestSqr) {
GaussianFiller<TypeParam> filler(this->filler_param_);
for (int l = 0; l < this->loops_; ++l) {
filler.Fill(this->a_);
caffe_sqr(this->num_, this->a_cpu_data_, this->y_cpu_data_);
for (int i = 0; i < this->num_; ++i) {
EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] * this->a_cpu_data_[i], this->near_delta_);
}
}
}

TYPED_TEST(MathFunctionsTest, TestExp) {
GaussianFiller<TypeParam> filler(this->filler_param_);
for (int l = 0; l < this->loops_; ++l) {
filler.Fill(this->a_);
caffe_exp(this->num_, this->a_cpu_data_, this->y_cpu_data_);
for (int i = 0; i < this->num_; ++i) {
EXPECT_NEAR(this->y_cpu_data_[i], std::exp(this->a_cpu_data_[i]), this->near_delta_);
}
}
}

} // namespace caffe
1 change: 1 addition & 0 deletions src/caffe/test/test_multinomial_logistic_loss_layer.cpp
@@ -24,6 +24,7 @@ class MultinomialLogisticLossLayerTest : public ::testing::Test {
MultinomialLogisticLossLayerTest()
: blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)) {
Caffe::set_random_seed(1701);
// fill the values
FillerParameter filler_param;
PositiveUnitballFiller<Dtype> filler(filler_param);