diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index e9e2db8f274..003d07cd4d3 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -1,8 +1,11 @@ // Copyright 2013 Yangqing Jia +// Copyright 2014 kloudkl@github #ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_ #define CAFFE_UTIL_MATH_FUNCTIONS_H_ +#include // for std::fabs +#include // for signbit #include #include @@ -100,6 +103,84 @@ Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y); template void caffe_gpu_dot(const int n, const Dtype* x, const Dtype* y, Dtype* out); +template +int caffe_hamming_distance(const int n, const Dtype* x, const Dtype* y); + +// Returns the sum of the absolute values of the elements of vector x +template +Dtype caffe_cpu_asum(const int n, const Dtype* x); + +template +void caffe_gpu_asum(const int n, const Dtype* x, Dtype* y); + +// Sign of val as -1, 0 or 1: the branchless, type-safe version from +// http://stackoverflow.com/questions/1903954/is-there-a-standard-sign-function-signum-sgn-in-c-c +template +inline signed char caffe_sign(Dtype val) { /* signed char, not plain char: plain char may be unsigned, which would turn -1 into 255 */ + return (Dtype(0) < val) - (val < Dtype(0)); +} + +// The following two macros are modifications of DEFINE_VSL_UNARY_FUNC +// in include/caffe/util/mkl_alternate.hpp authored by @Rowland Depp. +// Please refer to commit 7e8ef25c7 of the boost-eigen branch. 
+#define DEFINE_CAFFE_CPU_UNARY_FUNC(name, operation) \ + template \ + void caffe_cpu_##name(const int n, const Dtype* x, Dtype* y) { \ + CHECK_GT(n, 0); CHECK(x); CHECK(y); /* n must be positive, x and y non-null */ \ + for (int i = 0; i < n; ++i) { \ + operation; /* statement written in terms of x, y and i */ \ + } \ + } + +#define INSTANTIATE_CAFFE_CPU_UNARY_FUNC(name) \ + template <> \ + void caffe_cpu_##name(const int n, const float* x, float* y); \ + template <> \ + void caffe_cpu_##name(const int n, const double* x, double* y) + + +#define DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(name, operation) \ +template \ +__global__ void name##_kernel(const int n, const Dtype* x, Dtype* y) { \ + int index = threadIdx.x + blockIdx.x * blockDim.x; /* one element per thread */ \ + if (index < n) { /* threads with index >= n do nothing */ \ + operation; \ + } \ +} \ +template <> \ +void caffe_gpu_##name(const int n, const float* x, float* y) { \ + name##_kernel<<>>( \ + n, x, y); \ +} \ +template <> \ +void caffe_gpu_##name(const int n, const double* x, double* y) { \ + name##_kernel<<>>( \ + n, x, y); \ +} + +// output is 1 for the positives, 0 for zero, and -1 for the negatives +DEFINE_CAFFE_CPU_UNARY_FUNC(sign, y[i] = caffe_sign(x[i])); + +template +void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y); + +// returns a nonzero value if the input has its sign bit set. 
+DEFINE_CAFFE_CPU_UNARY_FUNC(signbit, y[i] = std::signbit(x[i])); + +template +void caffe_gpu_signbit(const int n, const Dtype* x, Dtype* y); + +DEFINE_CAFFE_CPU_UNARY_FUNC(fabs, y[i] = std::fabs(x[i])); + +template +void caffe_gpu_fabs(const int n, const Dtype* x, Dtype* y); + +template +void caffe_cpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y); + +template +void caffe_gpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y); + } // namespace caffe diff --git a/src/caffe/test/test_math_functions.cpp b/src/caffe/test/test_math_functions.cpp new file mode 100644 index 00000000000..d314d73b45c --- /dev/null +++ b/src/caffe/test/test_math_functions.cpp @@ -0,0 +1,190 @@ +// Copyright 2014 kloudkl@github + +#include // for uint32_t & uint64_t +#include // for std::fabs + +#include "gtest/gtest.h" +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/util/math_functions.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class MathFunctionsTest : public ::testing::Test { + protected: + MathFunctionsTest() + : blob_bottom_(new Blob()), + blob_top_(new Blob()) { + } + + virtual void SetUp() { + Caffe::set_random_seed(1701); + this->blob_bottom_->Reshape(100, 70, 50, 30); + this->blob_top_->Reshape(100, 70, 50, 30); + // fill both blobs using a GaussianFiller built from a default FillerParameter + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + filler.Fill(this->blob_top_); + } + + virtual ~MathFunctionsTest() { + delete blob_bottom_; + delete blob_top_; + } + // Reference Hamming distance, specialized per type by REF_HAMMING_DIST below; see http://en.wikipedia.org/wiki/Hamming_distance + int ReferenceHammingDistance(const int n, const Dtype* x, const Dtype* y); + + Blob* const blob_bottom_; + Blob* const blob_top_; +}; + +#define REF_HAMMING_DIST(float_type, int_type) \ +template<> \ +int MathFunctionsTest::ReferenceHammingDistance(const int n, \ + const float_type* x, \ + const float_type* y) { \ + int dist = 0; \ + int_type val; \ + for 
(int i = 0; i < n; ++i) { \ + val = static_cast(x[i]) ^ static_cast(y[i]); \ + /* Count the number of set bits */ \ + while (val) { \ + ++dist; \ + val &= val - 1; \ + } \ + } \ + return dist; \ +} + +REF_HAMMING_DIST(float, uint32_t); +REF_HAMMING_DIST(double, uint64_t); + +typedef ::testing::Types Dtypes; +TYPED_TEST_CASE(MathFunctionsTest, Dtypes); + +TYPED_TEST(MathFunctionsTest, TestHammingDistance){ + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + const TypeParam* y = this->blob_top_->cpu_data(); + CHECK_EQ(this->ReferenceHammingDistance(n, x, y), + caffe_hamming_distance(n, x, y)); +} + +TYPED_TEST(MathFunctionsTest, TestAsumCPU){ + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + TypeParam std_asum = 0; + for (int i = 0; i < n; ++i) { + std_asum += std::fabs(x[i]); + } + TypeParam cpu_asum = caffe_cpu_asum(n, x); + CHECK_LT((cpu_asum - std_asum) / std_asum, 1e-2); +} + +TYPED_TEST(MathFunctionsTest, TestAsumGPU){ + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + TypeParam std_asum = 0; + for (int i = 0; i < n; ++i) { + std_asum += std::fabs(x[i]); + } + TypeParam gpu_asum; + caffe_gpu_asum(n, this->blob_bottom_->gpu_data(), &gpu_asum); + CHECK_LT((gpu_asum - std_asum) / std_asum, 1e-2); +} + +TYPED_TEST(MathFunctionsTest, TestSignCPU){ + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + caffe_cpu_sign(n, x, this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* signs = this->blob_bottom_->cpu_diff(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? 
-1 : 0)); + } +} + +TYPED_TEST(MathFunctionsTest, TestSignGPU){ + int n = this->blob_bottom_->count(); + caffe_gpu_sign(n, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); /* output goes to the bottom blob's diff; its data is the input */ + const TypeParam* signs = this->blob_bottom_->cpu_diff(); /* NOTE(review): presumably the blob syncs the GPU result to the host here - confirm */ + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0)); + } +} + +TYPED_TEST(MathFunctionsTest, TestSignbitCPU){ + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + caffe_cpu_signbit(n, x, this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* signbits = this->blob_bottom_->cpu_diff(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(signbits[i], x[i] < 0 ? 1 : 0); /* expects exactly 1, not merely nonzero, for negative inputs */ + } +} + +TYPED_TEST(MathFunctionsTest, TestSignbitGPU){ + int n = this->blob_bottom_->count(); + caffe_gpu_signbit(n, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); + const TypeParam* signbits = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(signbits[i], x[i] < 0 ? 1 : 0); /* expects exactly 1, not merely nonzero, for negative inputs */ + } +} + +TYPED_TEST(MathFunctionsTest, TestFabsCPU){ + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + caffe_cpu_fabs(n, x, this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* abs_val = this->blob_bottom_->cpu_diff(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]); /* exact comparison against a hand-written absolute value */ + } +} + +TYPED_TEST(MathFunctionsTest, TestFabsGPU){ + int n = this->blob_bottom_->count(); + caffe_gpu_fabs(n, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); + const TypeParam* abs_val = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(abs_val[i], x[i] > 0 ? 
x[i] : -x[i]); + } +} + +TYPED_TEST(MathFunctionsTest, TestScaleCPU){ + int n = this->blob_bottom_->count(); + TypeParam alpha = this->blob_bottom_->cpu_diff()[rand() % + this->blob_bottom_->count()]; /* alpha: one element of the bottom blob's diff, picked with rand() */ + caffe_cpu_scale(n, alpha, this->blob_bottom_->cpu_data(), + this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* scaled = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(scaled[i], x[i] * alpha); + } +} + +TYPED_TEST(MathFunctionsTest, TestScaleGPU){ + int n = this->blob_bottom_->count(); + TypeParam alpha = this->blob_bottom_->cpu_diff()[rand() % + this->blob_bottom_->count()]; /* alpha: one element of the bottom blob's diff, picked with rand() */ + caffe_gpu_scale(n, alpha, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); + const TypeParam* scaled = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(scaled[i], x[i] * alpha); + } +} + +} diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 60656b87093..8a2f25e0c34 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -1,4 +1,5 @@ // Copyright 2013 Yangqing Jia +// Copyright 2014 kloudkl@github #include #include @@ -293,4 +294,78 @@ void caffe_gpu_dot(const int n, const double* x, const double* y, CUBLAS_CHECK(cublasDdot(Caffe::cublas_handle(), n, x, 1, y, 1, out)); } +template <> +int caffe_hamming_distance(const int n, const float* x, + const float* y) { + int dist = 0; + for (int i = 0; i < n; ++i) { + dist += __builtin_popcount(static_cast(x[i]) ^ + static_cast(y[i])); /* set bits in the XOR of the two converted elements */ + } + return dist; +} + +template <> +int caffe_hamming_distance(const int n, const double* x, + const double* y) { + int dist = 0; + for (int i = 0; i < n; ++i) { + dist += __builtin_popcountll(static_cast(x[i]) ^ + static_cast(y[i])); /* popcountll, not popcountl: unsigned long can be 32 bits wide and would drop the upper half of a 64-bit operand */ + } + return dist; +} + +template <> +float caffe_cpu_asum(const int n, const float* x) { + return cblas_sasum(n, x, 1); 
+} + +template <> +double caffe_cpu_asum(const int n, const double* x) { + return cblas_dasum(n, x, 1); +} + +template <> +void caffe_gpu_asum(const int n, const float* x, float* y) { + CUBLAS_CHECK(cublasSasum(Caffe::cublas_handle(), n, x, 1, y)); /* the sum is written through y */ +} + +template <> +void caffe_gpu_asum(const int n, const double* x, double* y) { + CUBLAS_CHECK(cublasDasum(Caffe::cublas_handle(), n, x, 1, y)); /* the sum is written through y */ +} + +INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sign); +INSTANTIATE_CAFFE_CPU_UNARY_FUNC(signbit); +INSTANTIATE_CAFFE_CPU_UNARY_FUNC(fabs); + +template <> +void caffe_cpu_scale(const int n, const float alpha, const float *x, + float* y) { + cblas_scopy(n, x, 1, y, 1); /* y = x */ + cblas_sscal(n, alpha, y, 1); /* y = alpha * y */ +} + +template <> +void caffe_cpu_scale(const int n, const double alpha, const double *x, + double* y) { + cblas_dcopy(n, x, 1, y, 1); /* y = x */ + cblas_dscal(n, alpha, y, 1); /* y = alpha * y */ +} + +template <> +void caffe_gpu_scale(const int n, const float alpha, const float *x, + float* y) { + CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), n, x, 1, y, 1)); /* y = x */ + CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), n, &alpha, y, 1)); /* y = alpha * y */ +} + +template <> +void caffe_gpu_scale(const int n, const double alpha, const double *x, + double* y) { + CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), n, x, 1, y, 1)); /* y = x */ + CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), n, &alpha, y, 1)); /* y = alpha * y */ +} + } // namespace caffe diff --git a/src/caffe/util/math_functions.cu b/src/caffe/util/math_functions.cu index b29a58abe7f..3ca5fea9a64 100644 --- a/src/caffe/util/math_functions.cu +++ b/src/caffe/util/math_functions.cu @@ -1,8 +1,10 @@ // Copyright 2013 Yangqing Jia +// Copyright 2014 kloudkl@github #include #include #include +#include // CUDA's, not caffe's, for fabs, signbit #include "caffe/common.hpp" #include "caffe/util/math_functions.hpp" @@ -34,5 +36,8 @@ void caffe_gpu_mul(const int N, const double* a, N, a, b, y); } +DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sign, y[index] = (Dtype(0) < x[index]) - (x[index] < Dtype(0))); 
+DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(signbit, y[index] = signbit(x[index])); /* unqualified call: CUDA's signbit, per the include comment in this file */ +DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(fabs, y[index] = fabs(x[index])); /* unqualified call: CUDA's fabs, per the include comment in this file */ } // namespace caffe