diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp
index f31d3b0f693..75cc3c67288 100644
--- a/include/caffe/blob.hpp
+++ b/include/caffe/blob.hpp
@@ -27,6 +27,14 @@ class Blob {
   inline int count() const {return count_; }
   inline int offset(const int n, const int c = 0, const int h = 0,
       const int w = 0) const {
+    CHECK_GE(n, 0);
+    CHECK_LE(n, num_);
+    CHECK_GE(channels_, 0);
+    CHECK_LE(c, channels_);
+    CHECK_GE(height_, 0);
+    CHECK_LE(h, height_);
+    CHECK_GE(width_, 0);
+    CHECK_LE(w, width_);
     return ((n * channels_ + c) * height_ + h) * width_ + w;
   }
   // Copy from source. If copy_diff is false, we copy the data; if copy_diff
diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp
index 05ed1301599..a4c63de0440 100644
--- a/include/caffe/util/math_functions.hpp
+++ b/include/caffe/util/math_functions.hpp
@@ -85,6 +85,9 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 template <typename Dtype>
 void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
 
+template <typename Dtype>
+Dtype caffe_nextafter(const Dtype b);
+
 template <typename Dtype>
 void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);
diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 2c23b456535..fd84866ce54 100644
--- a/include/caffe/vision_layers.hpp
+++ b/include/caffe/vision_layers.hpp
@@ -294,6 +294,7 @@ class DataLayer : public Layer<Dtype> {
  public:
   explicit DataLayer(const LayerParameter& param)
       : Layer<Dtype>(param) {}
+  virtual ~DataLayer();
   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 078d49708b6..d1262d03f24 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -17,8 +17,11 @@ namespace caffe {
 
 template <typename Dtype>
 void* DataLayerPrefetch(void* layer_pointer) {
+  CHECK(layer_pointer);
   DataLayer<Dtype>* layer = reinterpret_cast<DataLayer<Dtype>*>(layer_pointer);
+  CHECK(layer);
   Datum datum;
+  CHECK(layer->prefetch_data_);
   Dtype* top_data = layer->prefetch_data_->mutable_cpu_data();
   Dtype* top_label = layer->prefetch_label_->mutable_cpu_data();
   const Dtype scale = layer->layer_param_.scale();
@@ -38,6 +41,8 @@ void* DataLayerPrefetch(void* layer_pointer) {
   const Dtype* mean = layer->data_mean_.cpu_data();
   for (int itemid = 0; itemid < batchsize; ++itemid) {
     // get a blob
+    CHECK(layer->iter_);
+    CHECK(layer->iter_->Valid());
     datum.ParseFromString(layer->iter_->value().ToString());
     const string& data = datum.data();
     if (cropsize) {
@@ -109,6 +114,11 @@ void* DataLayerPrefetch(void* layer_pointer) {
   return (void*)NULL;
 }
 
+template <typename Dtype>
+DataLayer<Dtype>::~DataLayer() {
+  // Finally, join the thread
+  CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
+}
 
 template <typename Dtype>
 void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
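A note on the destructor added to data_layer.cpp above: DataLayer::SetUp spawns DataLayerPrefetch on thread_ (the pthread_create call is outside this hunk), so joining in the destructor guarantees the prefetch thread has stopped writing to prefetch_data_ before the layer's members are destroyed. A minimal standalone sketch of the same create-in-SetUp, join-in-destructor pattern; Worker and PrefetchHolder are illustrative names, not from the patch:

#include <pthread.h>
#include <glog/logging.h>

// Illustrative stand-in for DataLayerPrefetch: runs on the spawned thread.
static void* Worker(void* arg) {
  // ... fill prefetch buffers owned by the object behind arg ...
  return NULL;
}

class PrefetchHolder {
 public:
  void SetUp() {
    // Mirrors DataLayer<Dtype>::SetUp launching the prefetch thread.
    CHECK(!pthread_create(&thread_, NULL, Worker, this))
        << "Pthread execution failed.";
  }
  ~PrefetchHolder() {
    // Mirrors the added ~DataLayer: join before the members the thread uses die.
    CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
  }

 private:
  pthread_t thread_;
};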
diff --git a/src/caffe/layers/flatten_layer.cpp b/src/caffe/layers/flatten_layer.cpp
index f2467444809..f4ca6d0607f 100644
--- a/src/caffe/layers/flatten_layer.cpp
+++ b/src/caffe/layers/flatten_layer.cpp
@@ -43,6 +43,7 @@ Dtype FlattenLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->cpu_diff();
   Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
   caffe_copy(count_, top_diff, bottom_diff);
+  return Dtype(0);
 }
 
@@ -52,6 +53,7 @@ Dtype FlattenLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->gpu_diff();
   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
   caffe_gpu_copy(count_, top_diff, bottom_diff);
+  return Dtype(0);
 }
 
 INSTANTIATE_CLASS(FlattenLayer);
diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp
index fe3e915b5aa..66e9956838b 100644
--- a/src/caffe/test/test_data_layer.cpp
+++ b/src/caffe/test/test_data_layer.cpp
@@ -81,8 +81,8 @@ TYPED_TEST(DataLayerTest, TestRead) {
   EXPECT_EQ(this->blob_top_label_->channels(), 1);
   EXPECT_EQ(this->blob_top_label_->height(), 1);
   EXPECT_EQ(this->blob_top_label_->width(), 1);
-  // Go throught the data twice
-  for (int iter = 0; iter < 2; ++iter) {
+  // Go through the data 100 times
+  for (int iter = 0; iter < 100; ++iter) {
     layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
     for (int i = 0; i < 5; ++i) {
       EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]);
diff --git a/src/caffe/test/test_flatten_layer.cpp b/src/caffe/test/test_flatten_layer.cpp
index 805fd72eb5b..bb345d93302 100644
--- a/src/caffe/test/test_flatten_layer.cpp
+++ b/src/caffe/test/test_flatten_layer.cpp
@@ -22,6 +22,7 @@ class FlattenLayerTest : public ::testing::Test {
   FlattenLayerTest()
       : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
         blob_top_(new Blob<Dtype>()) {
+    Caffe::set_random_seed(1701);
     // fill the values
     FillerParameter filler_param;
     GaussianFiller<Dtype> filler(filler_param);
@@ -72,6 +73,8 @@ TYPED_TEST(FlattenLayerTest, TestGPU) {
   for (int c = 0; c < 3 * 6 * 5; ++c) {
     EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
         this->blob_bottom_->data_at(0, c / (6 * 5), (c / 5) % 6, c % 5));
+    EXPECT_EQ(this->blob_top_->data_at(1, c, 0, 0),
+        this->blob_bottom_->data_at(1, c / (6 * 5), (c / 5) % 6, c % 5));
   }
 }
diff --git a/src/caffe/test/test_gradient_check_util.hpp b/src/caffe/test/test_gradient_check_util.hpp
index d7360085d40..85edd05b693 100644
--- a/src/caffe/test/test_gradient_check_util.hpp
+++ b/src/caffe/test/test_gradient_check_util.hpp
@@ -82,11 +82,11 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
     blobs_to_check.push_back(bottom[check_bottom]);
   }
   // go through the bottom and parameter blobs
-  // LOG(ERROR) << "Checking " << blobs_to_check.size() << " blobs.";
+// LOG(ERROR) << "Checking " << blobs_to_check.size() << " blobs.";
   for (int blobid = 0; blobid < blobs_to_check.size(); ++blobid) {
     Blob<Dtype>* current_blob = blobs_to_check[blobid];
-    // LOG(ERROR) << "Blob " << blobid << ": checking " << current_blob->count()
-    //     << " parameters.";
+// LOG(ERROR) << "Blob " << blobid << ": checking " << current_blob->count()
+//     << " parameters.";
     // go through the values
     for (int feat_id = 0; feat_id < current_blob->count(); ++feat_id) {
       // First, obtain the original data
@@ -96,25 +96,28 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
       // Get any additional loss from the layer
       computed_objective += layer.Backward(top, true, &bottom);
       Dtype computed_gradient = current_blob->cpu_diff()[feat_id];
+      // compute score by adding stepsize
       current_blob->mutable_cpu_data()[feat_id] += stepsize_;
       Caffe::set_random_seed(seed_);
       layer.Forward(bottom, &top);
       Dtype positive_objective = GetObjAndGradient(top, top_id, top_data_id);
       positive_objective += layer.Backward(top, true, &bottom);
+      // compute score by subtracting stepsize
       current_blob->mutable_cpu_data()[feat_id] -= stepsize_ * 2;
       Caffe::set_random_seed(seed_);
       layer.Forward(bottom, &top);
       Dtype negative_objective = GetObjAndGradient(top, top_id, top_data_id);
       negative_objective += layer.Backward(top, true, &bottom);
+      // Recover stepsize
       current_blob->mutable_cpu_data()[feat_id] += stepsize_;
       Dtype estimated_gradient = (positive_objective - negative_objective) /
           stepsize_ / 2.;
       Dtype feature = current_blob->cpu_data()[feat_id];
-      // LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " "
-      //     << current_blob->cpu_diff()[feat_id];
+// LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " "
+//     << current_blob->cpu_diff()[feat_id];
       if (kink_ - kink_range_ > feature || feature > kink_ + kink_range_) {
         // We check relative accuracy, but for too small values, we threshold
         // the scale factor by 1.
@@ -126,10 +129,12 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>& layer,
         EXPECT_LT(computed_gradient, estimated_gradient + threshold_ * scale)
             << "debug: (top_id, top_data_id, blob_id, feat_id)="
            << top_id << "," << top_data_id << "," << blobid << "," << feat_id;
+// LOG(ERROR) << "computed gradient: " << computed_gradient
+//     << " estimated_gradient: " << estimated_gradient
+//     << " positive_objective: " << positive_objective
+//     << " negative_objective: " << negative_objective;
       }
-      // LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id];
-      // LOG(ERROR) << "computed gradient: " << computed_gradient
-      //     << " estimated_gradient: " << estimated_gradient;
+      // LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id]
     }
   }
 }
diff --git a/src/caffe/test/test_math_functions.cpp b/src/caffe/test/test_math_functions.cpp
new file mode 100644
index 00000000000..31973cc8619
--- /dev/null
+++ b/src/caffe/test/test_math_functions.cpp
@@ -0,0 +1,194 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cmath>
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/util/math_functions.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class MathFunctionsTest : public ::testing::Test {
+ protected:
+  MathFunctionsTest()
+      : loops_(10)
+      ,a_(new Blob<Dtype>(2, 3, 6, 5))
+      ,b_(new Blob<Dtype>(2, 3, 6, 5))
+      ,y_(new Blob<Dtype>(2, 3, 6, 5))
+      ,a_cpu_data_(a_->cpu_data())
+      ,b_cpu_data_(b_->cpu_data())
+      ,y_cpu_data_(y_->mutable_cpu_data())
+      ,near_delta_(1e-5)
+  {};
+
+  virtual void SetUp() {
+    num_ = a_->count();
+    filler_param_.set_min(1e-5);
+    filler_param_.set_max(10);
+  };
+
+  virtual ~MathFunctionsTest() {
+    delete a_;
+    delete b_;
+    delete y_;
+  }
+
+  int loops_;
+  int num_;
+  Blob<Dtype>* a_;
+  Blob<Dtype>* b_;
+  Blob<Dtype>* y_;
+  const Dtype* const a_cpu_data_;
+  const Dtype* const b_cpu_data_;
+  Dtype* y_cpu_data_;
+  const Dtype near_delta_;
+  FillerParameter filler_param_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(MathFunctionsTest, Dtypes);
+
+TYPED_TEST(MathFunctionsTest, TestAdd) {
+  GaussianFiller<TypeParam> filler(this->filler_param_);
+  for (int l = 0; l < this->loops_; ++l) {
+    filler.Fill(this->a_);
+    filler.Fill(this->b_);
+    caffe_add(this->num_, this->a_cpu_data_, this->b_cpu_data_,
+        this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i],
+          this->a_cpu_data_[i] + this->b_cpu_data_[i], this->near_delta_);
+    }
+  }
+}
+
+TYPED_TEST(MathFunctionsTest, TestSub) {
+  GaussianFiller<TypeParam> filler(this->filler_param_);
+  for (int l = 0; l < this->loops_; ++l) {
+    filler.Fill(this->a_);
+    filler.Fill(this->b_);
+    caffe_sub(this->num_, this->a_cpu_data_, this->b_cpu_data_,
+        this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i],
+          this->a_cpu_data_[i] - this->b_cpu_data_[i], this->near_delta_);
+    }
+  }
+}
+
+TYPED_TEST(MathFunctionsTest, TestMul) {
+  GaussianFiller<TypeParam> filler(this->filler_param_);
+  for (int l = 0; l < this->loops_; ++l) {
+    filler.Fill(this->a_);
+    filler.Fill(this->b_);
+    caffe_mul(this->num_, this->a_->cpu_data(), this->b_->cpu_data(),
+        this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i],
+          this->a_cpu_data_[i] * this->b_cpu_data_[i], this->near_delta_);
+    }
+  }
+}
+
+TYPED_TEST(MathFunctionsTest, TestDiv) {
+  GaussianFiller<TypeParam> filler(this->filler_param_);
+  UniformFiller<TypeParam> uniform_filler(this->filler_param_);
+  for (int l = 0; l < this->loops_; ++l) {
+    filler.Fill(this->a_);
+    filler.Fill(this->b_);
+    FillerParameter filler_param;
+    filler_param.set_min(1e-5);  // to avoid dividing by zero
+    uniform_filler.Fill(this->b_);
+    caffe_div(this->num_, this->a_cpu_data_, this->b_cpu_data_,
+        this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i], this->a_cpu_data_[i] /
+          this->b_cpu_data_[i], this->near_delta_);
+    }
+  }
+}
+
+TYPED_TEST(MathFunctionsTest, TestPowx) {
+  GaussianFiller<TypeParam> filler(this->filler_param_);
+  UniformFiller<TypeParam> uniform_filler(this->filler_param_);
+  TypeParam p;
+  for (int l = 0; l < this->loops_; ++l) {
+    p = 0;
+    filler.Fill(this->a_);
+    caffe_powx(this->num_, this->a_cpu_data_, p, this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i], std::pow(this->a_cpu_data_[i], p),
+          this->near_delta_)
+          << "debug: (i, y_cpu_data_, a_cpu_data_, p)="
+          << i << "," << this->y_cpu_data_[i] << "," << this->a_cpu_data_[i]
+          << "," << p;
+    }
+
+    p = 0.5;
+    uniform_filler.Fill(this->a_);
+    caffe_powx(this->num_, this->a_cpu_data_, p, this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i], std::pow(this->a_cpu_data_[i], p),
+          this->near_delta_)
+          << "debug: (i, y_cpu_data_, a_cpu_data_, p)="
+          << i << "," << this->y_cpu_data_[i] << "," << this->a_cpu_data_[i]
+          << "," << p;
+    }
+
+    p = -0.5;
+    uniform_filler.Fill(this->a_);
+    caffe_powx(this->num_, this->a_cpu_data_, p, this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i], std::pow(this->a_cpu_data_[i], p),
+          this->near_delta_)
+          << "debug: (i, y_cpu_data_, a_cpu_data_, p)="
+          << i << "," << this->y_cpu_data_[i] << "," << this->a_cpu_data_[i]
+          << "," << p;
+    }
+
+    p = 1.5;
+    uniform_filler.Fill(this->a_);
+    caffe_powx(this->num_, this->a_cpu_data_, p, this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i], std::pow(this->a_cpu_data_[i], p),
+          this->near_delta_)
+          << "debug: (i, y_cpu_data_, a_cpu_data_, p)="
+          << i << "," << this->y_cpu_data_[i] << "," << this->a_cpu_data_[i]
+          << "," << p;
+    }
+
+    p = -1.5;
+    uniform_filler.Fill(this->a_);
+    caffe_powx(this->num_, this->a_cpu_data_, p, this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i], std::pow(this->a_cpu_data_[i], p),
+          this->near_delta_)
+          << "debug: (i, y_cpu_data_, a_cpu_data_, p)="
+          << i << "," << this->y_cpu_data_[i] << "," << this->a_cpu_data_[i]
+          << "," << p;
+    }
+  }
+}
+
+TYPED_TEST(MathFunctionsTest, TestSqr) {
+  GaussianFiller<TypeParam> filler(this->filler_param_);
+  for (int l = 0; l < this->loops_; ++l) {
+    filler.Fill(this->a_);
+    caffe_sqr(this->num_, this->a_cpu_data_, this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i],
+          this->a_cpu_data_[i] * this->a_cpu_data_[i], this->near_delta_);
+    }
+  }
+}
+
+TYPED_TEST(MathFunctionsTest, TestExp) {
+  GaussianFiller<TypeParam> filler(this->filler_param_);
+  for (int l = 0; l < this->loops_; ++l) {
+    filler.Fill(this->a_);
+    caffe_exp(this->num_, this->a_cpu_data_, this->y_cpu_data_);
+    for (int i = 0; i < this->num_; ++i) {
+      EXPECT_NEAR(this->y_cpu_data_[i], std::exp(this->a_cpu_data_[i]),
+          this->near_delta_);
+    }
+  }
+}
+
+}  // namespace caffe
diff --git a/src/caffe/test/test_multinomial_logistic_loss_layer.cpp b/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
index adb36627606..6bd94ae24b8 100644
--- a/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
+++ b/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
@@ -24,6 +24,7 @@ class MultinomialLogisticLossLayerTest : public ::testing::Test {
   MultinomialLogisticLossLayerTest()
       : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
         blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)) {
+    Caffe::set_random_seed(1701);
     // fill the values
     FillerParameter filler_param;
     PositiveUnitballFiller<Dtype> filler(filler_param);
diff --git a/src/caffe/test/test_random_number_generator.cpp b/src/caffe/test/test_random_number_generator.cpp
new file mode 100644
index 00000000000..4c3358f9f49
--- /dev/null
+++ b/src/caffe/test/test_random_number_generator.cpp
@@ -0,0 +1,67 @@
+#include <cmath>
+#include <cstring>
+#include <cstdlib>
+
+#include "gtest/gtest.h"
+#include "caffe/common.hpp"
+#include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class RandomNumberGeneratorTest : public ::testing::Test {
+ public:
+  virtual ~RandomNumberGeneratorTest() {}
+
+  Dtype sample_mean(const Dtype* const seqs, const size_t sample_size) {
+    double sum = 0;
+    for (int i = 0; i < sample_size; ++i) {
+      sum += seqs[i];
+    }
+    return sum / sample_size;
+  }
+
+  Dtype mean_bound(const Dtype std, const size_t sample_size) {
+    return std / sqrt((double)sample_size);
+  }
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(RandomNumberGeneratorTest, Dtypes);
+
+TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussian) {
+  size_t sample_size = 10000;
+  SyncedMemory data_a(sample_size * sizeof(TypeParam));
+  Caffe::set_random_seed(1701);
+  TypeParam mu = 0;
+  TypeParam sigma = 1;
+  caffe_vRngGaussian(sample_size,
+      (TypeParam*)data_a.mutable_cpu_data(), mu, sigma);
+  TypeParam true_mean = mu;
+  TypeParam true_std = sigma;
+  TypeParam bound = this->mean_bound(true_std, sample_size);
+  TypeParam real_mean = this->sample_mean(
+      (TypeParam*)data_a.cpu_data(), sample_size);
+  EXPECT_NEAR(real_mean, true_mean, bound);
+}
+
+TYPED_TEST(RandomNumberGeneratorTest, TestRngUniform) {
+  size_t sample_size = 10000;
+  SyncedMemory data_a(sample_size * sizeof(TypeParam));
+  Caffe::set_random_seed(1701);
+  TypeParam lower = 0;
+  TypeParam upper = 1;
+  caffe_vRngUniform(sample_size,
+      (TypeParam*)data_a.mutable_cpu_data(), lower, upper);
+  TypeParam true_mean = (lower + upper) / 2;
+  TypeParam true_std = (upper - lower) / sqrt(12);
+  TypeParam bound = this->mean_bound(true_std, sample_size);
+  TypeParam real_mean = this->sample_mean(
+      (TypeParam*)data_a.cpu_data(), sample_size);
+  EXPECT_NEAR(real_mean, true_mean, bound);
+}
+
+}  // namespace caffe
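The bound these RNG tests use is the standard error of the mean, std/sqrt(sample_size); for a uniform distribution on [lower, upper] the standard deviation is (upper - lower)/sqrt(12). A quick standalone check of the numbers TestRngUniform relies on:

#include <cmath>
#include <cstdio>

// Reproduces the arithmetic behind TestRngUniform: for U(0, 1),
// sigma = 1/sqrt(12) and the one-standard-error bound is sigma/sqrt(n).
int main() {
  const double lower = 0, upper = 1;
  const size_t n = 10000;
  double sigma = (upper - lower) / std::sqrt(12.0);  // ~0.2887
  double bound = sigma / std::sqrt((double)n);       // ~0.0029
  printf("sigma = %.4f, mean bound = %.4f\n", sigma, bound);
  return 0;
}

Note that for an ideal generator a one-standard-error bound on the sample mean fails roughly a third of the time, so the fixed seed 1701 is what makes these assertions deterministic.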
diff --git a/src/caffe/test/test_stochastic_pooing.cpp b/src/caffe/test/test_stochastic_pooing.cpp
index e2b60eeec34..b8b07cb5999 100644
--- a/src/caffe/test/test_stochastic_pooing.cpp
+++ b/src/caffe/test/test_stochastic_pooing.cpp
@@ -140,8 +140,6 @@ TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPUTestPhase) {
   }
 }
 
-
-
 TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) {
   Caffe::set_mode(Caffe::GPU);
   Caffe::set_phase(Caffe::TRAIN);
@@ -151,12 +149,10 @@ TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) {
   layer_param.set_pool(LayerParameter_PoolMethod_STOCHASTIC);
   PoolingLayer<TypeParam> layer(layer_param);
 
-  GradientChecker<TypeParam> checker(1e-2, 1e-3);
+  GradientChecker<TypeParam> checker(1e-4, 1e-2);
   // it is too expensive to call curand multiple times, so we don't do an
   // exhaustive gradient check.
   checker.CheckGradient(layer, this->blob_bottom_vec_, this->blob_top_vec_);
 }
-
-
 }
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index 25de4251a00..f292972e4e2 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -1,7 +1,9 @@
 // Copyright 2013 Yangqing Jia
 
+#include <boost/math/special_functions/next.hpp>
 //#include <mkl.h>
 #include <boost/random.hpp>
+#include <limits>
 
 #include <eigen3/Eigen/Dense>
 #include <cublas_v2.h>
@@ -10,11 +12,22 @@ namespace caffe {
 
-const int data_alignment = Eigen::Aligned; // how is data allocated ?
-typedef Eigen::Map<const Eigen::VectorXf, data_alignment> const_map_vector_float_t;
-typedef Eigen::Map<Eigen::VectorXf, data_alignment> map_vector_float_t;
-typedef Eigen::Map<const Eigen::VectorXd, data_alignment> const_map_vector_double_t;
-typedef Eigen::Map<Eigen::VectorXd, data_alignment> map_vector_double_t;
+// Operations on aligned memory are faster than on unaligned memory.
+// But unfortunately, the pointers passed in are not always aligned.
+// Therefore, the memory-aligned Eigen::Map objects that wrap them
+// cannot be assigned to. This happens in lrn_layer and makes
+// test_lrn_layer crash with segmentation fault.
+// TODO: Use aligned Eigen::Map when the pointer to be wrapped is aligned.
+
+// Though the default map option is unaligned, making it explicit is no harm.
+//const int data_alignment = Eigen::Aligned; // how is data allocated ?
+const int data_alignment = Eigen::Unaligned;
+typedef Eigen::Array<float, 1, Eigen::Dynamic> float_array_t;
+typedef Eigen::Map<const float_array_t, data_alignment> const_map_vector_float_t;
+typedef Eigen::Map<float_array_t, data_alignment> map_vector_float_t;
+typedef Eigen::Array<double, 1, Eigen::Dynamic> double_array_t;
+typedef Eigen::Map<const double_array_t, data_alignment> const_map_vector_double_t;
+typedef Eigen::Map<double_array_t, data_alignment> map_vector_double_t;
 
 template<>
 void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
@@ -125,25 +138,6 @@ void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
   CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
 }
 
-template <>
-void caffe_axpby<float>(const int N, const float alpha, const float* X,
-    const float beta, float* Y) {
-  // y := a*x + b*y
-  //cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
-  map_vector_float_t(Y, N) *= beta;
-  map_vector_float_t(Y, N) += (alpha * const_map_vector_float_t(X, N));
-
-}
-
-template <>
-void caffe_axpby<double>(const int N, const double alpha, const double* X,
-    const double beta, double* Y) {
-  // y := a*x + b*y
-  //cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
-  map_vector_double_t(Y, N) *= beta;
-  map_vector_double_t(Y, N) += (alpha * const_map_vector_double_t(X, N));
-}
-
 template <>
 void caffe_copy<float>(const int N, const float* X, float* Y) {
   cblas_scopy(N, X, 1, Y, 1);
@@ -199,180 +193,275 @@ void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
 }
 
 template <>
-void caffe_sqr<float>(const int n, const float* a, float* y) {
-  //vsSqr(n, a, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().sqrt();
+void caffe_axpby<float>(const int N, const float alpha, const float* X,
+    const float beta, float* Y) {
+  // y := a*x + b*y
+  //cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
+  CHECK_GE(N, 0);
+  CHECK(X);
+  CHECK(Y);
+  map_vector_float_t y_map(Y, N);
+  // Eigen produces optimized code using lazy evaluation
+  // http://eigen.tuxfamily.org/dox/TopicLazyEvaluation.html
+  y_map = const_map_vector_float_t(X, N) * alpha + y_map * beta;
 }
 
 template <>
-void caffe_sqr<double>(const int n, const double* a, double* y) {
-  //vdSqr(n, a, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().sqrt();
+void caffe_axpby<double>(const int N, const double alpha, const double* X,
+    const double beta, double* Y) {
+  // y := a*x + b*y
+  //cblas_daxpby(N, alpha, X, 1, beta, Y, 1);
+  CHECK_GE(N, 0);
+  CHECK(X);
+  CHECK(Y);
+  map_vector_double_t y_map(Y, N);
+  y_map = const_map_vector_double_t(X, N) * alpha + y_map * beta;
 }
 
 template <>
 void caffe_add<float>(const int n, const float* a, const float* b, float* y) {
-  //vsAdd(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) + const_map_vector_float_t(b, n);
+  //vsAdd(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) +
+      const_map_vector_float_t(b, n);
 }
 
 template <>
 void caffe_add<double>(const int n, const double* a, const double* b, double* y) {
-  //vdAdd(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) + const_map_vector_double_t(b, n);
+  //vdAdd(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) +
+      const_map_vector_double_t(b, n);
 }
 
 template <>
 void caffe_sub<float>(const int n, const float* a, const float* b, float* y) {
-  //vsSub(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) - const_map_vector_float_t(b, n);
+  //vsSub(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) -
+      const_map_vector_float_t(b, n);
 }
 
 template <>
 void caffe_sub<double>(const int n, const double* a, const double* b, double* y) {
-  //vdSub(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) - const_map_vector_double_t(b, n);
+  //vdSub(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) -
+      const_map_vector_double_t(b, n);
 }
 
 template <>
 void caffe_mul<float>(const int n, const float* a, const float* b, float* y) {
-  //vsMul(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() * const_map_vector_float_t(b, n).array();
+  //vsMul(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) *
+      const_map_vector_float_t(b, n);
 }
 
 template <>
 void caffe_mul<double>(const int n, const double* a, const double* b, double* y) {
-  //vdMul(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() * const_map_vector_double_t(b, n).array();
+  //vdMul(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) *
+      const_map_vector_double_t(b, n);
 }
 
 template <>
 void caffe_div<float>(const int n, const float* a, const float* b, float* y) {
-  //vsDiv(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array() / const_map_vector_float_t(b, n).array();
+  //vsDiv(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) /
+      const_map_vector_float_t(b, n);
 }
 
 template <>
 void caffe_div<double>(const int n, const double* a, const double* b, double* y) {
-  //vdDiv(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array() / const_map_vector_double_t(b, n).array();
+  //vdDiv(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(b);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n) /
+      const_map_vector_double_t(b, n);
 }
 
 template <>
 void caffe_powx<float>(const int n, const float* a, const float b, float* y) {
-  //vsPowx(n, a, b, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().pow(b);
+  //vsPowx(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).pow(b);
 }
 
 template <>
 void caffe_powx<double>(const int n, const double* a, const double b, double* y) {
-  //vdPowx(n, a, b, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().pow(b);
+  //vdPowx(n, a, b, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).pow(b);
 }
 
 template <>
-void caffe_vRngUniform<float>(const int n, float* r,
-    const float a, const float b) {
+void caffe_sqr<float>(const int n, const float* a, float* y) {
+  // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-F003F826-81BF-42EC-AE51-2EF624893133.htm
+  // v?Sqr Performs element by element squaring of the vector.
+  //vsSqr(n, a, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  caffe_powx<float>(n, a, 2, y);
+  // TODO: which is faster?
+//  map_vector_float_t(y, n) = const_map_vector_float_t(a, n) *
+//      const_map_vector_float_t(a, n);
+}
+
+template <>
+void caffe_sqr<double>(const int n, const double* a, double* y) {
+  //vdSqr(n, a, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  caffe_powx<double>(n, a, 2, y);
+}
+
+template <>
+void caffe_exp<float>(const int n, const float* a, float* y) {
+  //vsExp(n, a, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).exp();
+}
+
+template <>
+void caffe_exp<double>(const int n, const double* a, double* y) {
+  //vdExp(n, a, y);
+  CHECK_GE(n, 0);
+  CHECK(a);
+  CHECK(y);
+  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).exp();
+}
+
+template <typename Dtype>
+Dtype caffe_nextafter(const Dtype b) {
+  return boost::math::nextafter<Dtype>(
+      b, std::numeric_limits<Dtype>::max());
+}
+
+template
+float caffe_nextafter(const float b);
+
+template
+double caffe_nextafter(const double b);
+
+template <typename Dtype>
+void caffe_vRngUniform(const int n, Dtype* r,
+    const Dtype a, const Dtype b) {
+  CHECK_GE(n, 0);
+  CHECK(r);
+  CHECK_LE(a, b);
   //VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
   //    n, r, a, b));
 
   // FIXME check if boundaries are handled in the same way ?
-  boost::uniform_real<float> random_distribution(a, b);
+  // Fixed by caffe_nextafter
+  boost::random::uniform_real_distribution<Dtype> random_distribution(
+      a, caffe_nextafter<Dtype>(b));
   Caffe::random_generator_t &generator = Caffe::vsl_stream();
 
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
+  for(int i = 0; i < n; i += 1) {
+    r[i] = random_distribution(generator);
   }
 }
 
-template <>
+template
+void caffe_vRngUniform<float>(const int n, float* r,
+    const float a, const float b);
+template
 void caffe_vRngUniform<double>(const int n, double* r,
-    const double a, const double b) {
-  //VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
-  //    n, r, a, b));
-
-  // FIXME check if boundaries are handled in the same way ?
-  boost::uniform_real<double> random_distribution(a, b);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
-
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
-  }
-}
+    const double a, const double b);
 
-template <>
-void caffe_vRngGaussian<float>(const int n, float* r, const float a,
-    const float sigma) {
+template <typename Dtype>
+void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
+    const Dtype sigma) {
+  CHECK_GE(n, 0);
+  CHECK(r);
+  CHECK_GT(sigma, 0);
   //VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
   //    Caffe::vsl_stream(), n, r, a, sigma));
 
   // FIXME check if parameters are handled in the same way ?
-  boost::normal_distribution<float> random_distribution(a, sigma);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
+  // http://www.boost.org/doc/libs/1_55_0/doc/html/boost/random/normal_distribution.html
+  // http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-63196F25-5013-4038-8BCD-2613C4EF3DE4.htm
+  // The above two documents show that the probability density functions are
+  // different. But the unit tests still pass. Maybe their codes are the same
+  // or the tests are irrelevant to the random numbers.
+  boost::normal_distribution<Dtype> random_distribution(a, sigma);
+  Caffe::random_generator_t &generator = Caffe::vsl_stream();
 
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
-  }
+  for(int i = 0; i < n; i += 1) {
+    r[i] = random_distribution(generator);
+  }
 }
 
+template
+void caffe_vRngGaussian<float>(const int n, float* r, const float a,
+    const float sigma);
 
-template <>
+template
 void caffe_vRngGaussian<double>(const int n, double* r, const double a,
-    const double sigma) {
-  //VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
-  //    Caffe::vsl_stream(), n, r, a, sigma));
-
-  // FIXME check if parameters are handled in the same way ?
-  boost::normal_distribution<double> random_distribution(a, sigma);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
-
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
-  }
-}
-
+    const double sigma);
 
 template <typename Dtype>
-void caffe_vRngBernoulli(const int n, Dtype* r, const double p)
-{
+void caffe_vRngBernoulli(const int n, Dtype* r, const double p) {
+  CHECK_GE(n, 0);
+  CHECK(r);
+  CHECK_GE(p, 0);
+  CHECK_LE(p, 1);
   // FIXME check if parameters are handled in the same way ?
-  boost::bernoulli_distribution<double> random_distribution(p);
-  Caffe::random_generator_t &generator = Caffe::vsl_stream();
-
-  for(int i = 0; i < n; i += 1)
-  {
-    r[i] = random_distribution(generator);
-  }
-}
-
-template void caffe_vRngBernoulli<int>(const int n, int* r, const double p);
-
+  boost::bernoulli_distribution<double> random_distribution(p);
+  Caffe::random_generator_t &generator = Caffe::vsl_stream();
 
-template <>
-void caffe_exp<float>(const int n, const float* a, float* y) {
-  //vsExp(n, a, y);
-  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().exp();
+  for(int i = 0; i < n; i += 1) {
+    r[i] = random_distribution(generator);
+  }
 }
 
-template <>
-void caffe_exp<double>(const int n, const double* a, double* y) {
-  //vdExp(n, a, y);
-  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().exp();
-}
+template
+void caffe_vRngBernoulli<int>(const int n, int* r, const double p);
 
 template <>
 float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {
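Two Eigen behaviors underpin the math_functions.cpp changes above: Map<..., Eigen::Unaligned> can wrap arbitrary caller-supplied pointers (the Aligned variant segfaulted in test_lrn_layer, per the comment block), and expression templates evaluate y = x * alpha + y * beta lazily in a single fused pass, as the caffe_axpby comment notes. A minimal sketch assuming only Eigen, mirroring the typedefs in the patch:

#include <eigen3/Eigen/Dense>
#include <cstdio>

// Sketch of the caffe_axpby pattern: wrap raw, possibly unaligned
// pointers in Eigen::Map and let lazy evaluation fuse y = a*x + b*y
// into one loop with no temporaries.
typedef Eigen::Array<float, 1, Eigen::Dynamic> float_array_t;
typedef Eigen::Map<const float_array_t, Eigen::Unaligned> const_map_t;
typedef Eigen::Map<float_array_t, Eigen::Unaligned> map_t;

int main() {
  float x[4] = {1, 2, 3, 4};
  float y[4] = {10, 20, 30, 40};
  const float alpha = 2, beta = 0.5;
  map_t y_map(y, 4);
  // Coefficient-wise, so reading and writing y in one expression is safe.
  y_map = const_map_t(x, 4) * alpha + y_map * beta;
  for (int i = 0; i < 4; ++i) printf("%g ", y[i]);  // prints: 7 14 21 28
  printf("\n");
  return 0;
}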
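On caffe_nextafter: boost::random::uniform_real_distribution(a, b) samples the half-open interval [a, b), so passing the next representable value above b as the upper parameter makes b itself attainable; that is what the "Fixed by caffe_nextafter" comment refers to. A standalone sketch, with a local mt19937 standing in for Caffe::vsl_stream(), whose definition is outside this diff:

#include <boost/math/special_functions/next.hpp>
#include <boost/random.hpp>
#include <cstdio>
#include <limits>

// Widen the half-open range [a, b) to the closed range [a, b] by
// bumping the upper bound to the next representable float above b.
template <typename Dtype>
Dtype next_above(const Dtype b) {
  return boost::math::nextafter<Dtype>(b, std::numeric_limits<Dtype>::max());
}

int main() {
  const float b = 1.0f;
  printf("b = %.9g, next above b = %.9g\n", b, next_above(b));
  boost::mt19937 generator(1701);
  boost::random::uniform_real_distribution<float> dist(0.0f, next_above(b));
  printf("sample in [0, 1]: %.9g\n", dist(generator));
  return 0;
}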
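The templated caffe_vRngBernoulli keeps the same engine-plus-distribution shape as the other generators: construct bernoulli_distribution(p) once, then draw n samples from the shared engine. A self-contained sketch, again with a local engine in place of Caffe::vsl_stream():

#include <boost/random.hpp>
#include <cstdio>

// Sketch of the caffe_vRngBernoulli loop: a Bernoulli(p) distribution
// driven by one engine, writing into an int buffer as the explicit
// <int> instantiation in the diff does.
int main() {
  const int n = 10;
  const double p = 0.3;
  int r[n];
  boost::mt19937 generator(1701);
  boost::bernoulli_distribution<double> random_distribution(p);
  for (int i = 0; i < n; i += 1) {
    r[i] = random_distribution(generator);  // bool promoted to 0 or 1
  }
  for (int i = 0; i < n; ++i) printf("%d", r[i]);
  printf("\n");
  return 0;
}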