Skip to content

Commit

Permalink
Merge pull request #159 from kloudkl/math_functions
Browse files Browse the repository at this point in the history
Add more convenience math functions: hamming distance, sum of absolute values, elementwise sign and abs, and non-in-place scaling.
  • Loading branch information
shelhamer committed Mar 10, 2014
2 parents 32fb333 + c9d9056 commit 980c00d
Show file tree
Hide file tree
Showing 4 changed files with 351 additions and 0 deletions.
81 changes: 81 additions & 0 deletions include/caffe/util/math_functions.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github

#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
#define CAFFE_UTIL_MATH_FUNCTIONS_H_

#include <cmath> // for std::fabs
#include <math.h> // for signbit
#include <mkl.h>
#include <cublas_v2.h>

Expand Down Expand Up @@ -100,6 +103,84 @@ Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
template <typename Dtype>
void caffe_gpu_dot(const int n, const Dtype* x, const Dtype* y, Dtype* out);

// Returns the Hamming distance between x and y accumulated over n elements.
// The float and double specializations in src/caffe/util/math_functions.cpp
// convert each element to an unsigned integer with static_cast (a value
// conversion, not a reinterpretation of the bit pattern), XOR the two
// integers and count the set bits of the result.
template <typename Dtype>
int caffe_hamming_distance(const int n, const Dtype* x, const Dtype* y);

// Returns the sum of the absolute values of the elements of vector x
// (n elements). Specialized for float and double with cblas_sasum /
// cblas_dasum in src/caffe/util/math_functions.cpp.
template <typename Dtype>
Dtype caffe_cpu_asum(const int n, const Dtype* x);

// GPU counterpart: the float and double specializations call cublasSasum /
// cublasDasum on x and store the result through y.
// NOTE(review): the unit test passes gpu_data() for x and the address of a
// host variable for y — confirm this matches the cuBLAS pointer mode in use.
template <typename Dtype>
void caffe_gpu_asum(const int n, const Dtype* x, Dtype* y);

// the branchless, type-safe version from
// http://stackoverflow.com/questions/1903954/is-there-a-standard-sign-function-signum-sgn-in-c-c
//
// Returns the sign of val: 1 if val is positive, -1 if it is negative and 0
// otherwise (including when neither comparison holds, e.g. for NaN).
//
// The result type is explicitly `signed char`: whether plain `char` is signed
// is implementation-defined, and where it is unsigned a result of -1 would be
// returned as 255.
template<typename Dtype>
inline signed char caffe_sign(Dtype val) {
  // Each comparison yields 0 or 1, so the difference is always in {-1, 0, 1}.
  return static_cast<signed char>((Dtype(0) < val) - (val < Dtype(0)));
}

// The following macros are modifications of DEFINE_VSL_UNARY_FUNC
// in include/caffe/util/mkl_alternate.hpp authored by @Rowland Depp.
// Please refer to commit 7e8ef25c7 of the boost-eigen branch.
//
// DEFINE_CAFFE_CPU_UNARY_FUNC(name, operation) defines the function template
//   caffe_cpu_<name>(const int n, const Dtype* x, Dtype* y)
// which checks that n > 0 and that x and y are non-null, then evaluates
// `operation` once for every i in [0, n). `operation` is pasted verbatim into
// the loop body, so it may refer to n, x, y, i and Dtype.
#define DEFINE_CAFFE_CPU_UNARY_FUNC(name, operation) \
template<typename Dtype> \
void caffe_cpu_##name(const int n, const Dtype* x, Dtype* y) { \
CHECK_GT(n, 0); CHECK(x); CHECK(y); \
for (int i = 0; i < n; ++i) { \
operation; \
} \
}

// INSTANTIATE_CAFFE_CPU_UNARY_FUNC(name) explicitly instantiates the
// caffe_cpu_<name> function template for float and double in the translation
// unit that expands it.
//
// The directives use `template` rather than `template <>`: a `template <>`
// declaration without a body only declares an explicit specialization (which
// would then have to be defined somewhere) and instantiates nothing.
// The final directive has no trailing semicolon; the caller supplies it.
#define INSTANTIATE_CAFFE_CPU_UNARY_FUNC(name) \
  template \
  void caffe_cpu_##name<float>(const int n, const float* x, float* y); \
  template \
  void caffe_cpu_##name<double>(const int n, const double* x, double* y)


// DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(name, operation) defines:
//   * a CUDA kernel template <name>_kernel in which each thread computes
//       index = threadIdx.x + blockIdx.x * blockDim.x
//     and evaluates `operation` only when index < n;
//   * float and double specializations of caffe_gpu_<name> that launch the
//     kernel with CAFFE_GET_BLOCKS(n) blocks of CAFFE_CUDA_NUM_THREADS threads.
// `operation` is pasted verbatim into the kernel, so it may refer to n, x, y,
// index and Dtype.
// NOTE(review): the kernel launch is not followed by an error check here —
// confirm whether callers rely on a later CUDA error check.
#define DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(name, operation) \
template<typename Dtype> \
__global__ void name##_kernel(const int n, const Dtype* x, Dtype* y) { \
int index = threadIdx.x + blockIdx.x * blockDim.x; \
if (index < n) { \
operation; \
} \
} \
template <> \
void caffe_gpu_##name<float>(const int n, const float* x, float* y) { \
name##_kernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>( \
n, x, y); \
} \
template <> \
void caffe_gpu_##name<double>(const int n, const double* x, double* y) { \
name##_kernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>( \
n, x, y); \
}

// output is 1 for the positives, 0 for zero, and -1 for the negatives
// Defines caffe_cpu_sign(n, x, y), which stores caffe_sign(x[i]) in y[i].
DEFINE_CAFFE_CPU_UNARY_FUNC(sign, y[i] = caffe_sign<Dtype>(x[i]));

// GPU counterpart of caffe_cpu_sign. The float and double specializations are
// generated in src/caffe/util/math_functions.cu.
template<typename Dtype>
void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y);

// returns a nonzero value if the input has its sign bit set.
// Defines caffe_cpu_signbit(n, x, y), which stores the result of
// std::signbit(x[i]), converted to Dtype, in y[i].
DEFINE_CAFFE_CPU_UNARY_FUNC(signbit, y[i] = std::signbit(x[i]));

// GPU counterpart of caffe_cpu_signbit. The float and double specializations
// are generated in src/caffe/util/math_functions.cu.
template<typename Dtype>
void caffe_gpu_signbit(const int n, const Dtype* x, Dtype* y);

// Defines caffe_cpu_fabs(n, x, y), which stores std::fabs(x[i]) in y[i].
DEFINE_CAFFE_CPU_UNARY_FUNC(fabs, y[i] = std::fabs(x[i]));

// GPU counterpart of caffe_cpu_fabs. The float and double specializations are
// generated in src/caffe/util/math_functions.cu.
template <typename Dtype>
void caffe_gpu_fabs(const int n, const Dtype* x, Dtype* y);

// Non-in-place scaling: y = alpha * x over n elements. The float and double
// specializations in src/caffe/util/math_functions.cpp copy x into y with the
// BLAS copy routine and then scale y by alpha with the BLAS scal routine.
template <typename Dtype>
void caffe_cpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y);

// GPU counterpart of caffe_cpu_scale, implemented with the cuBLAS copy and
// scal routines; alpha is passed to cuBLAS by address.
template <typename Dtype>
void caffe_gpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y);

} // namespace caffe


Expand Down
190 changes: 190 additions & 0 deletions src/caffe/test/test_math_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// Copyright 2014 kloudkl@github

#include <stdint.h> // for uint32_t & uint64_t
#include <cmath> // for std::fabs

#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/util/math_functions.hpp"

#include "caffe/test/test_caffe_main.hpp"

namespace caffe {

// Typed fixture for the math function tests. It owns two blobs of identical
// shape; SetUp() reseeds Caffe's random generator and fills the data of both
// blobs with a default-configured Gaussian filler.
template<typename Dtype>
class MathFunctionsTest : public ::testing::Test {
 protected:
  MathFunctionsTest()
      : blob_bottom_(new Blob<Dtype>()),
        blob_top_(new Blob<Dtype>()) {}

  // Releases the blobs allocated by the constructor.
  virtual ~MathFunctionsTest() {
    delete blob_bottom_;
    delete blob_top_;
  }

  virtual void SetUp() {
    // Fixed seed so that every run sees the same input values.
    Caffe::set_random_seed(1701);
    blob_bottom_->Reshape(100, 70, 50, 30);
    blob_top_->Reshape(100, 70, 50, 30);
    // Populate both blobs, bottom first, from the same filler.
    FillerParameter gaussian_param;
    GaussianFiller<Dtype> gaussian(gaussian_param);
    gaussian.Fill(blob_bottom_);
    gaussian.Fill(blob_top_);
  }

  // Reference implementation that the library result is compared against.
  // http://en.wikipedia.org/wiki/Hamming_distance
  int ReferenceHammingDistance(const int n, const Dtype* x, const Dtype* y);

  Blob<Dtype>* const blob_bottom_;
  Blob<Dtype>* const blob_top_;
};

// REF_HAMMING_DIST(float_type, int_type) defines the specialization
// MathFunctionsTest<float_type>::ReferenceHammingDistance. For each of the n
// element pairs it converts both values to int_type with static_cast, XORs
// them and counts the set bits of the result by repeatedly clearing the
// lowest set bit (val &= val - 1). The per-pair counts are summed.
#define REF_HAMMING_DIST(float_type, int_type) \
template<> \
int MathFunctionsTest<float_type>::ReferenceHammingDistance(const int n, \
const float_type* x, \
const float_type* y) { \
int dist = 0; \
int_type val; \
for (int i = 0; i < n; ++i) { \
val = static_cast<int_type>(x[i]) ^ static_cast<int_type>(y[i]); \
/* Count the number of set bits */ \
while (val) { \
++dist; \
val &= val - 1; \
} \
} \
return dist; \
}

// Reference implementations for the two tested element types, each paired
// with an unsigned integer type of the same width.
REF_HAMMING_DIST(float, uint32_t);
REF_HAMMING_DIST(double, uint64_t);

// Run every TYPED_TEST of MathFunctionsTest once with float and once with
// double as TypeParam.
typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(MathFunctionsTest, Dtypes);

// caffe_hamming_distance must agree with the fixture's reference
// implementation on the data of the two blobs.
TYPED_TEST(MathFunctionsTest, TestHammingDistance){
  const TypeParam* const y = this->blob_top_->cpu_data();
  const TypeParam* const x = this->blob_bottom_->cpu_data();
  const int n = this->blob_bottom_->count();
  CHECK_EQ(this->ReferenceHammingDistance(n, x, y),
           caffe_hamming_distance<TypeParam>(n, x, y));
}

// caffe_cpu_asum must match a straightforward accumulation of std::fabs over
// the bottom blob's data to within a relative error of 1e-2.
TYPED_TEST(MathFunctionsTest, TestAsumCPU){
  int n = this->blob_bottom_->count();
  const TypeParam* x = this->blob_bottom_->cpu_data();
  TypeParam std_asum = 0;
  for (int i = 0; i < n; ++i) {
    std_asum += std::fabs(x[i]);
  }
  TypeParam cpu_asum = caffe_cpu_asum<TypeParam>(n, x);
  // Compare the magnitude of the relative error: without std::fabs a result
  // that is far too small gives a negative quotient and passes the check.
  CHECK_LT(std::fabs(cpu_asum - std_asum) / std_asum, 1e-2);
}

// caffe_gpu_asum must match a straightforward CPU accumulation of std::fabs
// over the bottom blob's data to within a relative error of 1e-2.
TYPED_TEST(MathFunctionsTest, TestAsumGPU){
  int n = this->blob_bottom_->count();
  const TypeParam* x = this->blob_bottom_->cpu_data();
  TypeParam std_asum = 0;
  for (int i = 0; i < n; ++i) {
    std_asum += std::fabs(x[i]);
  }
  TypeParam gpu_asum;
  caffe_gpu_asum<TypeParam>(n, this->blob_bottom_->gpu_data(), &gpu_asum);
  // Compare the magnitude of the relative error: without std::fabs a result
  // that is far too small gives a negative quotient and passes the check.
  CHECK_LT(std::fabs(gpu_asum - std_asum) / std_asum, 1e-2);
}

// caffe_cpu_sign writes the sign of each data element into the diff buffer of
// the same blob; every output must be 1, -1 or 0 to match the input's sign.
TYPED_TEST(MathFunctionsTest, TestSignCPU){
  const int n = this->blob_bottom_->count();
  const TypeParam* x = this->blob_bottom_->cpu_data();
  TypeParam* const sign_out = this->blob_bottom_->mutable_cpu_diff();
  caffe_cpu_sign<TypeParam>(n, x, sign_out);
  const TypeParam* signs = this->blob_bottom_->cpu_diff();
  int i = 0;
  while (i < n) {
    CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0));
    ++i;
  }
}

// caffe_gpu_sign runs on the blob's GPU data and diff buffers; the result is
// then read back through the CPU accessors and checked element by element.
TYPED_TEST(MathFunctionsTest, TestSignGPU){
  const int n = this->blob_bottom_->count();
  caffe_gpu_sign<TypeParam>(n, this->blob_bottom_->gpu_data(),
                            this->blob_bottom_->mutable_gpu_diff());
  const TypeParam* x = this->blob_bottom_->cpu_data();
  const TypeParam* signs = this->blob_bottom_->cpu_diff();
  int i = 0;
  while (i < n) {
    CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0));
    ++i;
  }
}

// caffe_cpu_signbit writes into the diff buffer of the bottom blob; each
// output is expected to be 1 for a negative input and 0 otherwise.
TYPED_TEST(MathFunctionsTest, TestSignbitCPU){
  const int n = this->blob_bottom_->count();
  const TypeParam* x = this->blob_bottom_->cpu_data();
  TypeParam* const signbit_out = this->blob_bottom_->mutable_cpu_diff();
  caffe_cpu_signbit<TypeParam>(n, x, signbit_out);
  const TypeParam* signbits = this->blob_bottom_->cpu_diff();
  int i = 0;
  while (i < n) {
    CHECK_EQ(signbits[i], x[i] < 0 ? 1 : 0);
    ++i;
  }
}

// caffe_gpu_signbit runs on the blob's GPU buffers; the output is read back
// through the CPU accessors and must be 1 for negative inputs, 0 otherwise.
TYPED_TEST(MathFunctionsTest, TestSignbitGPU){
  const int n = this->blob_bottom_->count();
  caffe_gpu_signbit<TypeParam>(n, this->blob_bottom_->gpu_data(),
                               this->blob_bottom_->mutable_gpu_diff());
  const TypeParam* x = this->blob_bottom_->cpu_data();
  const TypeParam* signbits = this->blob_bottom_->cpu_diff();
  int i = 0;
  while (i < n) {
    CHECK_EQ(signbits[i], x[i] < 0 ? 1 : 0);
    ++i;
  }
}

// caffe_cpu_fabs writes into the diff buffer of the bottom blob; each output
// must equal the input when it is positive and its negation otherwise.
TYPED_TEST(MathFunctionsTest, TestFabsCPU){
  const int n = this->blob_bottom_->count();
  const TypeParam* x = this->blob_bottom_->cpu_data();
  TypeParam* const abs_out = this->blob_bottom_->mutable_cpu_diff();
  caffe_cpu_fabs<TypeParam>(n, x, abs_out);
  const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
  int i = 0;
  while (i < n) {
    CHECK_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
    ++i;
  }
}

// caffe_gpu_fabs runs on the blob's GPU buffers; the output is read back
// through the CPU accessors and compared with the negate-if-not-positive
// reference.
TYPED_TEST(MathFunctionsTest, TestFabsGPU){
  const int n = this->blob_bottom_->count();
  caffe_gpu_fabs<TypeParam>(n, this->blob_bottom_->gpu_data(),
                            this->blob_bottom_->mutable_gpu_diff());
  const TypeParam* x = this->blob_bottom_->cpu_data();
  const TypeParam* abs_val = this->blob_bottom_->cpu_diff();
  int i = 0;
  while (i < n) {
    CHECK_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]);
    ++i;
  }
}

// caffe_cpu_scale must write alpha * x[i] into the diff buffer for every
// element. alpha is read from a rand()-selected position of the diff buffer
// before that buffer is overwritten.
// NOTE(review): nothing in this fixture fills the diff buffer beforehand, so
// the value of alpha depends on how Blob initializes it — confirm.
TYPED_TEST(MathFunctionsTest, TestScaleCPU){
  const int n = this->blob_bottom_->count();
  const int alpha_pos = rand() % this->blob_bottom_->count();
  TypeParam alpha = this->blob_bottom_->cpu_diff()[alpha_pos];
  caffe_cpu_scale<TypeParam>(n, alpha, this->blob_bottom_->cpu_data(),
                             this->blob_bottom_->mutable_cpu_diff());
  const TypeParam* x = this->blob_bottom_->cpu_data();
  const TypeParam* scaled = this->blob_bottom_->cpu_diff();
  int i = 0;
  while (i < n) {
    CHECK_EQ(scaled[i], x[i] * alpha);
    ++i;
  }
}

// caffe_gpu_scale must write alpha * x[i] into the diff buffer for every
// element. alpha is read from a rand()-selected position of the CPU view of
// the diff buffer before the GPU call overwrites that buffer.
// NOTE(review): nothing in this fixture fills the diff buffer beforehand, so
// the value of alpha depends on how Blob initializes it — confirm.
TYPED_TEST(MathFunctionsTest, TestScaleGPU){
  const int n = this->blob_bottom_->count();
  const int alpha_pos = rand() % this->blob_bottom_->count();
  TypeParam alpha = this->blob_bottom_->cpu_diff()[alpha_pos];
  caffe_gpu_scale<TypeParam>(n, alpha, this->blob_bottom_->gpu_data(),
                             this->blob_bottom_->mutable_gpu_diff());
  const TypeParam* scaled = this->blob_bottom_->cpu_diff();
  const TypeParam* x = this->blob_bottom_->cpu_data();
  int i = 0;
  while (i < n) {
    CHECK_EQ(scaled[i], x[i] * alpha);
    ++i;
  }
}

}
75 changes: 75 additions & 0 deletions src/caffe/util/math_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github

#include <mkl.h>
#include <cublas_v2.h>
Expand Down Expand Up @@ -293,4 +294,78 @@ void caffe_gpu_dot<double>(const int n, const double* x, const double* y,
CUBLAS_CHECK(cublasDdot(Caffe::cublas_handle(), n, x, 1, y, 1, out));
}

// Hamming distance for float vectors: every element of x and y is converted
// to uint32_t by value, the two integers are XORed, and the set bits of the
// result are added to the running total.
// NOTE(review): the conversion is a value cast, not a bit reinterpretation;
// converting a negative or out-of-range float to an unsigned type is
// undefined behavior in C++ — confirm the intended input domain.
template <>
int caffe_hamming_distance<float>(const int n, const float* x,
                                  const float* y) {
  int differing_bits = 0;
  int idx = 0;
  while (idx < n) {
    const uint32_t lhs = static_cast<uint32_t>(x[idx]);
    const uint32_t rhs = static_cast<uint32_t>(y[idx]);
    differing_bits += __builtin_popcount(lhs ^ rhs);
    ++idx;
  }
  return differing_bits;
}

// Hamming distance for double vectors: every element of x and y is converted
// to uint64_t by value, the two integers are XORed, and the set bits of the
// result are added to the running total.
// NOTE(review): the conversion is a value cast, not a bit reinterpretation;
// converting a negative or out-of-range double to an unsigned type is
// undefined behavior in C++ — confirm the intended input domain.
template <>
int caffe_hamming_distance<double>(const int n, const double* x,
                                   const double* y) {
  int dist = 0;
  for (int i = 0; i < n; ++i) {
    // __builtin_popcountll takes unsigned long long, which is at least 64
    // bits wide. __builtin_popcountl takes unsigned long, which is only 32
    // bits on ILP32/LLP64 targets and would drop the upper half of the XOR.
    dist += __builtin_popcountll(static_cast<uint64_t>(x[i]) ^
                                 static_cast<uint64_t>(y[i]));
  }
  return dist;
}

// Single-precision sum of absolute values, delegated to cblas_sasum with a
// stride of 1 over the n elements of x.
template <>
float caffe_cpu_asum<float>(const int n, const float* x) {
  const int unit_stride = 1;
  const float abs_sum = cblas_sasum(n, x, unit_stride);
  return abs_sum;
}

// Double-precision sum of absolute values, delegated to cblas_dasum with a
// stride of 1 over the n elements of x.
template <>
double caffe_cpu_asum<double>(const int n, const double* x) {
  const int unit_stride = 1;
  const double abs_sum = cblas_dasum(n, x, unit_stride);
  return abs_sum;
}

// Single-precision sum of absolute values on the GPU: calls cublasSasum on
// Caffe's cuBLAS handle with a stride of 1 and stores the result through y.
// The cuBLAS status is verified by CUBLAS_CHECK.
template <>
void caffe_gpu_asum<float>(const int n, const float* x, float* y) {
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasSasum(Caffe::cublas_handle(), n, x, unit_stride, y));
}

// Double-precision sum of absolute values on the GPU: calls cublasDasum on
// Caffe's cuBLAS handle with a stride of 1 and stores the result through y.
// The cuBLAS status is verified by CUBLAS_CHECK.
template <>
void caffe_gpu_asum<double>(const int n, const double* x, double* y) {
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasDasum(Caffe::cublas_handle(), n, x, unit_stride, y));
}

// Expand INSTANTIATE_CAFFE_CPU_UNARY_FUNC (defined in
// include/caffe/util/math_functions.hpp) for the float and double versions of
// caffe_cpu_sign, caffe_cpu_signbit and caffe_cpu_fabs.
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sign);
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(signbit);
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(fabs);

// Single-precision y = alpha * x: x is first copied into y with cblas_scopy,
// then y is scaled in place by alpha with cblas_sscal. Both use stride 1.
template <>
void caffe_cpu_scale<float>(const int n, const float alpha, const float *x,
                            float* y) {
  const int unit_stride = 1;
  cblas_scopy(n, x, unit_stride, y, unit_stride);
  cblas_sscal(n, alpha, y, unit_stride);
}

// Double-precision y = alpha * x: x is first copied into y with cblas_dcopy,
// then y is scaled in place by alpha with cblas_dscal. Both use stride 1.
template <>
void caffe_cpu_scale<double>(const int n, const double alpha, const double *x,
                             double* y) {
  const int unit_stride = 1;
  cblas_dcopy(n, x, unit_stride, y, unit_stride);
  cblas_dscal(n, alpha, y, unit_stride);
}

// Single-precision y = alpha * x on the GPU: cublasScopy copies x into y and
// cublasSscal then scales y in place. alpha is handed to cuBLAS by address,
// and each call's status is verified by CUBLAS_CHECK.
template <>
void caffe_gpu_scale<float>(const int n, const float alpha, const float *x,
                            float* y) {
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), n, x, unit_stride, y,
                           unit_stride));
  CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), n, &alpha, y,
                           unit_stride));
}

// Double-precision y = alpha * x on the GPU: cublasDcopy copies x into y and
// cublasDscal then scales y in place. alpha is handed to cuBLAS by address,
// and each call's status is verified by CUBLAS_CHECK.
template <>
void caffe_gpu_scale<double>(const int n, const double alpha, const double *x,
                             double* y) {
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), n, x, unit_stride, y,
                           unit_stride));
  CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), n, &alpha, y,
                           unit_stride));
}

} // namespace caffe
5 changes: 5 additions & 0 deletions src/caffe/util/math_functions.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <math_functions.h> // CUDA's, not caffe's, for fabs, signbit

#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
Expand Down Expand Up @@ -34,5 +36,8 @@ void caffe_gpu_mul<double>(const int N, const double* a,
N, a, b, y);
}

// Element-wise GPU sign, signbit and fabs. Each expansion defines the kernel
// <name>_kernel and the float/double specializations of caffe_gpu_<name>.
// The sign expression mirrors the comparison-difference form of caffe_sign.
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sign, y[index] = (Dtype(0) < x[index]) - (x[index] < Dtype(0)));
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(signbit, y[index] = signbit(x[index]));
DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(fabs, y[index] = fabs(x[index]));

} // namespace caffe

0 comments on commit 980c00d

Please sign in to comment.