Skip to content

Commit

Permalink
Merge pull request #165 from BVLC/boost-eigen
Browse files Browse the repository at this point in the history
MKL/non-MKL Reconciliation

Caffe no longer requires MKL. By default it builds without it, relying on atlas and cblas instead. Set the `USE_MKL` var in your Makefile.config accordingly.
  • Loading branch information
shelhamer committed Mar 23, 2014
2 parents 510b3c0 + bece205 commit 699b557
Show file tree
Hide file tree
Showing 21 changed files with 493 additions and 143 deletions.
24 changes: 17 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,37 @@ CUDA_LIB_DIR := $(CUDA_DIR)/lib64 $(CUDA_DIR)/lib
MKL_INCLUDE_DIR := $(MKL_DIR)/include
MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64

INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
LIBRARY_DIRS += $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR)
LIBRARY_DIRS += $(CUDA_LIB_DIR)
LIBRARIES := cudart cublas curand \
mkl_rt \
pthread \
glog protobuf leveldb \
snappy \
glog protobuf leveldb snappy \
boost_system \
hdf5_hl hdf5 \
opencv_core opencv_highgui opencv_imgproc
PYTHON_LIBRARIES := boost_python python2.7
WARNINGS := -Wall

COMMON_FLAGS := -DNDEBUG -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
COMMON_FLAGS := -DNDEBUG -O2

# MKL switch (default = non-MKL)
# `?=` only assigns 0 when USE_MKL is not already defined, so a value set
# earlier (presumably in Makefile.config) or on the command line wins.
USE_MKL ?= 0
ifeq ($(USE_MKL), 1)
# MKL build: link mkl_rt, pass -DUSE_MKL to the compilers, and add the MKL
# include and library search directories.
LIBRARIES += mkl_rt
COMMON_FLAGS += -DUSE_MKL
INCLUDE_DIRS += $(MKL_INCLUDE_DIR)
LIBRARY_DIRS += $(MKL_LIB_DIR)
else
# Any value other than 1: link against cblas and atlas instead of MKL.
LIBRARIES += cblas atlas
endif

COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS)
NVCCFLAGS := -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
$(foreach library,$(LIBRARIES),-l$(library))
PYTHON_LDFLAGS := $(LDFLAGS) $(foreach library,$(PYTHON_LIBRARIES),-l$(library))


##############################
# Define build targets
##############################
Expand Down
2 changes: 2 additions & 0 deletions Makefile.config.example
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ CUDA_ARCH := -gencode arch=compute_20,code=sm_20 \
-gencode arch=compute_30,code=sm_30 \
-gencode arch=compute_35,code=sm_35

# MKL switch: set to 1 for MKL
USE_MKL := 0
# MKL directory contains include/ and lib/ directories that we need.
MKL_DIR := /opt/intel/mkl

Expand Down
8 changes: 8 additions & 0 deletions include/caffe/blob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ class Blob {
inline int count() const {return count_; }
// Returns the flat index of element (n, c, h, w), computed as
// ((n * channels_ + c) * height_ + h) * width_ + w.
//
// Every index is checked to be non-negative and no larger than the matching
// dimension. The upper bounds are inclusive (CHECK_LE), so an index equal to
// the dimension is accepted and yields a one-past-the-end offset.
// The dimension members themselves are also checked to be non-negative.
inline int offset(const int n, const int c = 0, const int h = 0,
    const int w = 0) const {
  CHECK_GE(n, 0);
  CHECK_LE(n, num_);
  CHECK_GE(channels_, 0);
  CHECK_GE(c, 0);  // a negative index would silently produce a wrong offset
  CHECK_LE(c, channels_);
  CHECK_GE(height_, 0);
  CHECK_GE(h, 0);
  CHECK_LE(h, height_);
  CHECK_GE(width_, 0);
  CHECK_GE(w, 0);
  CHECK_LE(w, width_);
  return ((n * channels_ + c) * height_ + h) * width_ + w;
}
// Copy from source. If copy_diff is false, we copy the data; if copy_diff
Expand Down
102 changes: 59 additions & 43 deletions include/caffe/common.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2013 Yangqing Jia
// Copyright 2014 BVLC and contributors.

#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_
Expand All @@ -7,28 +7,8 @@
#include <cublas_v2.h>
#include <cuda.h>
#include <curand.h>
// cuda driver types
#include <driver_types.h>
#include <driver_types.h> // cuda driver types
#include <glog/logging.h>
#include <mkl_vsl.h>

// various checks for different function calls.
#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)

#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)

// After a kernel is executed, this will check the error and if there is one,
// exit loudly.
#define CUDA_POST_KERNEL_CHECK \
if (cudaSuccess != cudaPeekAtLastError()) \
LOG(FATAL) << "Cuda kernel failed. Error: " \
<< cudaGetErrorString(cudaPeekAtLastError())

// Disable the copy and assignment operator for a class.
#define DISABLE_COPY_AND_ASSIGN(classname) \
Expand All @@ -45,6 +25,23 @@ private:\
// is executed we will see a fatal log.
#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"

// CUDA: various checks for different function calls.
// Each macro wraps `condition` in CHECK_EQ against the success value of the
// matching API: cudaSuccess for CUDA_CHECK, CUBLAS_STATUS_SUCCESS for
// CUBLAS_CHECK and CURAND_STATUS_SUCCESS for CURAND_CHECK.
// `condition` is parenthesized, so any expression can be passed.
#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)

// CUDA: grid stride looping
// Expands to a `for` header that declares `int i`, starts it at
// blockIdx.x * blockDim.x + threadIdx.x, and advances it by
// blockDim.x * gridDim.x for as long as i < n. The loop body is supplied by
// the caller: CUDA_KERNEL_LOOP(index, count) { ... }
// `n` is parenthesized but is re-evaluated on every iteration, so it should
// be free of side effects.
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)

// CUDA: check for error after kernel execution and exit loudly if there is one.
// Usage: `CUDA_POST_KERNEL_CHECK;` right after a kernel launch.
// The error is read once with cudaPeekAtLastError() and reused for the log
// message. The body is wrapped in do { ... } while (0) so the macro behaves
// as a single statement and cannot capture a following `else`.
#define CUDA_POST_KERNEL_CHECK \
  do { \
    const cudaError_t caffe_post_kernel_status = cudaPeekAtLastError(); \
    if (cudaSuccess != caffe_post_kernel_status) { \
      LOG(FATAL) << "Cuda kernel failed. Error: " \
                 << cudaGetErrorString(caffe_post_kernel_status); \
    } \
  } while (0)


namespace caffe {

Expand All @@ -53,20 +50,6 @@ namespace caffe {
using boost::shared_ptr;


// We will use 1024 threads per block, which requires cuda sm_2x or above.
#if __CUDA_ARCH__ >= 200
const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
const int CAFFE_CUDA_NUM_THREADS = 512;
#endif



inline int CAFFE_GET_BLOCKS(const int N) {
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}


// A singleton class to hold common caffe stuff, such as the handler that
// caffe is going to use for cublas, curand, etc.
class Caffe {
Expand All @@ -81,15 +64,32 @@ class Caffe {
enum Brew { CPU, GPU };
enum Phase { TRAIN, TEST };

// The getters for the variables.
// Returns the cublas handle.

// This random number generator facade hides boost and CUDA rng
// implementation from one another (for cross-platform compatibility).
// The concrete engine sits behind the forward-declared Generator class, so
// this header does not expose the engine's type.
class RNG {
public:
// Constructs a generator without an explicit seed.
RNG();
// Constructs a generator from the given seed.
explicit RNG(unsigned int seed);
~RNG();
// Copy construction and assignment are user-declared because the class holds
// a raw Generator pointer.
// NOTE(review): whether these deep-copy the engine is decided in the
// implementation file, which is not visible here -- confirm.
RNG(const RNG&);
RNG& operator=(const RNG&);
// Untyped access to the underlying engine. Callers cast the pointer to the
// concrete engine type (caffe/util/rng.hpp casts it to boost::mt19937).
const void* generator() const;
void* generator();
private:
// Opaque implementation type, defined outside this header.
class Generator;
Generator* generator_;
};

// Getters for boost rng, curand, and cublas handles
// Returns a mutable reference to the RNG facade (random_generator_) of the
// Caffe instance returned by Get().
inline static RNG &rng_stream() {
return Get().random_generator_;
}
inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
// Returns the curand generator.
inline static curandGenerator_t curand_generator() {
return Get().curand_generator_;
}
// Returns the MKL random stream.
inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }

// Returns the mode: running on CPU or GPU.
inline static Brew mode() { return Get().mode_; }
// Returns the phase: TRAIN or TEST.
Expand All @@ -102,7 +102,7 @@ class Caffe {
inline static void set_mode(Brew mode) { Get().mode_ = mode; }
// Sets the phase.
inline static void set_phase(Phase phase) { Get().phase_ = phase; }
// Sets the random seed of both MKL and curand
// Sets the random seed of both boost and curand
static void set_random_seed(const unsigned int seed);
// Sets the device. Since we have cublas and curand stuff, set device also
// requires us to reset those values.
Expand All @@ -113,7 +113,8 @@ class Caffe {
protected:
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
VSLStreamStatePtr vsl_stream_;
RNG random_generator_;

Brew mode_;
Phase phase_;
static shared_ptr<Caffe> singleton_;
Expand All @@ -126,6 +127,21 @@ class Caffe {
};


// CUDA: threads-per-block setting.
// 1024 threads per block are used when __CUDA_ARCH__ reports sm_2x or above.
// In every other case, including when __CUDA_ARCH__ is not defined at all,
// 512 is used as a compatibility fallback (best of luck to you).
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 200)
const int CAFFE_CUDA_NUM_THREADS = 512;
#else
const int CAFFE_CUDA_NUM_THREADS = 1024;
#endif

// CUDA: number of blocks of CAFFE_CUDA_NUM_THREADS threads needed to cover
// N elements, i.e. N / CAFFE_CUDA_NUM_THREADS rounded up for non-negative N.
inline int CAFFE_GET_BLOCKS(const int N) {
  const int round_up_bias = CAFFE_CUDA_NUM_THREADS - 1;
  return (N + round_up_bias) / CAFFE_CUDA_NUM_THREADS;
}


} // namespace caffe

#endif // CAFFE_COMMON_HPP_
1 change: 0 additions & 1 deletion include/caffe/filler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#ifndef CAFFE_FILLER_HPP
#define CAFFE_FILLER_HPP

#include <mkl.h>
#include <string>

#include "caffe/common.hpp"
Expand Down
12 changes: 10 additions & 2 deletions include/caffe/util/math_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
#define CAFFE_UTIL_MATH_FUNCTIONS_H_

#include <mkl.h>

#include <cublas_v2.h>

#include "caffe/util/mkl_alternate.hpp"

namespace caffe {

// Decaf gemm provides a simpler interface to the gemm functions, with the
Expand Down Expand Up @@ -45,7 +47,7 @@ void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
Dtype* Y);

template <typename Dtype>
void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
void caffe_cpu_axpby(const int N, const Dtype alpha, const Dtype* X,
const Dtype beta, Dtype* Y);

template <typename Dtype>
Expand Down Expand Up @@ -85,13 +87,19 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
template <typename Dtype>
void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);

template <typename Dtype>
Dtype caffe_nextafter(const Dtype b);

template <typename Dtype>
void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);

template <typename Dtype>
void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
const Dtype sigma);

template <typename Dtype>
void caffe_vRngBernoulli(const int n, Dtype* r, const double p);

template <typename Dtype>
void caffe_exp(const int n, const Dtype* a, Dtype* y);

Expand Down
97 changes: 97 additions & 0 deletions include/caffe/util/mkl_alternate.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2013 Rowland Depp

#ifndef CAFFE_UTIL_MKL_ALTERNATE_H_
#define CAFFE_UTIL_MKL_ALTERNATE_H_

#ifdef USE_MKL

#include <mkl.h>

#else  // USE_MKL is not defined: use CBLAS plus the MKL replacements below

extern "C" {
#include <cblas.h>
}
#include <math.h>

// Functions that caffe uses but are not present if MKL is not linked.

// A simple way to define the vsl unary functions. The operation should
// be in the form e.g. y[i] = sqrt(a[i])
//
// DEFINE_VSL_UNARY_FUNC(name, operation) generates three functions:
//   v<name><Dtype>(n, a, y) - template that runs `operation` for i in [0, n)
//   vs<name>(n, a, y)       - float wrapper around v<name><float>
//   vd<name>(n, a, y)       - double wrapper around v<name><double>
// `operation` is pasted into the loop body and may refer to n, a, y and i.
// The generated template CHECKs that n > 0 and that a and y are non-null.
#define DEFINE_VSL_UNARY_FUNC(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vs##name( \
const int n, const float* a, float* y) { \
v##name<float>(n, a, y); \
} \
inline void vd##name( \
const int n, const double* a, double* y) { \
v##name<double>(n, a, y); \
}

// Element-wise square (vsSqr / vdSqr) and exponential (vsExp / vdExp).
DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));

// A simple way to define the vsl unary functions with singular parameter b.
// The operation should be in the form e.g. y[i] = pow(a[i], b)
//
// DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) generates:
//   v<name><Dtype>(n, a, b, y) - template that runs `operation` for i in [0, n)
//   vs<name>(n, a, b, y)       - float wrapper around v<name><float>
//   vd<name>(n, a, b, y)       - double wrapper around v<name><double>
// `operation` is pasted into the loop body and may refer to n, a, b, y and i.
// The generated template CHECKs that n > 0 and that a and y are non-null.
//
// The double wrapper takes `b` as double, so the scalar parameter is not
// truncated to single precision before it reaches the double template.
#define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) \
  template<typename Dtype> \
  void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \
    CHECK_GT(n, 0); CHECK(a); CHECK(y); \
    for (int i = 0; i < n; ++i) { operation; } \
  } \
  inline void vs##name( \
      const int n, const float* a, const float b, float* y) { \
    v##name<float>(n, a, b, y); \
  } \
  inline void vd##name( \
      const int n, const double* a, const double b, double* y) { \
    v##name<double>(n, a, b, y); \
  }

DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = pow(a[i], b));

// A simple way to define the vsl binary functions. The operation should
// be in the form e.g. y[i] = a[i] + b[i]
//
// DEFINE_VSL_BINARY_FUNC(name, operation) generates three functions:
//   v<name><Dtype>(n, a, b, y) - template that runs `operation` for i in [0, n)
//   vs<name>(n, a, b, y)       - float wrapper around v<name><float>
//   vd<name>(n, a, b, y)       - double wrapper around v<name><double>
// `operation` is pasted into the loop body and may refer to n, a, b, y and i.
// The generated template CHECKs that n > 0 and that a, b and y are non-null.
#define DEFINE_VSL_BINARY_FUNC(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(b); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vs##name( \
const int n, const float* a, const float* b, float* y) { \
v##name<float>(n, a, b, y); \
} \
inline void vd##name( \
const int n, const double* a, const double* b, double* y) { \
v##name<double>(n, a, b, y); \
}

// Element-wise add, subtract, multiply and divide (vsAdd / vdAdd, ...).
// Div performs a plain a[i] / b[i] with no check for a zero divisor.
DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i]);
DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i]);
DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i]);
DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]);

// In addition, MKL comes with an additional function axpby that is not present
// in standard blas. We will simply use a two-step (inefficient, of course) way
// to mimic that.
//
// Single precision: Y = alpha * X + beta * Y, done by first scaling Y by beta
// (cblas_sscal) and then accumulating alpha * X into it (cblas_saxpy).
// NOTE: Y is overwritten before X is read, so the result differs from a true
// axpby when X and Y overlap.
inline void cblas_saxpby(const int N, const float alpha, const float* X,
const int incX, const float beta, float* Y,
const int incY) {
cblas_sscal(N, beta, Y, incY);
cblas_saxpy(N, alpha, X, incX, Y, incY);
}
// Double precision: Y = alpha * X + beta * Y, done by first scaling Y by beta
// (cblas_dscal) and then accumulating alpha * X into it (cblas_daxpy).
// NOTE: Y is overwritten before X is read, so the result differs from a true
// axpby when X and Y overlap.
inline void cblas_daxpby(const int N, const double alpha, const double* X,
const int incX, const double beta, double* Y,
const int incY) {
cblas_dscal(N, beta, Y, incY);
cblas_daxpy(N, alpha, X, incX, Y, incY);
}

#endif // USE_MKL
#endif // CAFFE_UTIL_MKL_ALTERNATE_H_
19 changes: 19 additions & 0 deletions include/caffe/util/rng.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright 2014 BVLC and contributors.

#ifndef CAFFE_RNG_CPP_HPP_
#define CAFFE_RNG_CPP_HPP_

#include <boost/random/mersenne_twister.hpp>
#include "caffe/common.hpp"

namespace caffe {

typedef boost::mt19937 rng_t;
// Returns the engine held by Caffe::rng_stream(), viewed as an rng_t.
// The facade only exposes an untyped pointer, so it is cast back here.
// Assumes the facade's Generator really wraps an rng_t -- TODO confirm
// against the RNG implementation.
inline rng_t& caffe_rng() {
  void* const engine = Caffe::rng_stream().generator();
  return *static_cast<rng_t*>(engine);
}

} // namespace caffe

#endif  // CAFFE_RNG_CPP_HPP_
Loading

0 comments on commit 699b557

Please sign in to comment.