Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MKL/non-MKL merge #97

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64

INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
LIBRARY_DIRS += $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
LIBRARIES := cudart cublas curand mkl_rt pthread \
LIBRARIES := cudart cublas curand pthread \
glog protobuf leveldb snappy boost_system \
opencv_core opencv_highgui opencv_imgproc
PYTHON_LIBRARIES := boost_python python2.7
Expand All @@ -82,6 +82,14 @@ LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
$(foreach library,$(LIBRARIES),-l$(library))
PYTHON_LDFLAGS := $(LDFLAGS) $(foreach library,$(PYTHON_LIBRARIES),-l$(library))

# MKL options
ifdef USE_MKL
LIBRARIES += mkl_rt
COMMON_FLAGS += -DUSE_MKL
else
LIBRARIES += atlas cblas
endif


##############################
# Define build targets
Expand Down
2 changes: 2 additions & 0 deletions Makefile.config.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ CUDA_ARCH := -gencode arch=compute_20,code=sm_20 \
-gencode arch=compute_30,code=sm_30 \
-gencode arch=compute_35,code=sm_35

# If not using MKL, comment out the following line.
# USE_MKL=1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check if the

ifdef USE_MKL

in Makefile will be false if USE_MKL=0 (I suspect not)? There might be a
hidden bug here by making the if statement ambiguous. Essentially
we should be checking the value of USE_MKL instead of checking whether it
is defined.

I am sure you are more familiar with Makefile jargons so please kindly fix
it :)

Yangqing

On Fri, Feb 14, 2014 at 1:19 PM, Evan Shelhamer notifications@github.comwrote:

In Makefile.config.example:

@@ -7,6 +7,8 @@ CUDA_ARCH := -gencode arch=compute_20,code=sm_20
-gencode arch=compute_30,code=sm_30
-gencode arch=compute_35,code=sm_35

+# If not using MKL, comment out the following line.
+# USE_MKL=1

TODO this should be =0. I'll fix during merge (real soon now).

Reply to this email directly or view it on GitHubhttps://github.com//pull/97/files#r9764360
.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I shouldn't work on Caffe before I have my coffee :) I'll make sure it's right when I merge.

# MKL directory contains include/ and lib/ directions that we need.
MKL_DIR := /opt/intel/mkl

Expand Down
8 changes: 8 additions & 0 deletions include/caffe/blob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ class Blob {
inline int count() const {return count_; }
inline int offset(const int n, const int c = 0, const int h = 0,
const int w = 0) const {
CHECK_GE(n, 0);
CHECK_LE(n, num_);
CHECK_GE(channels_, 0);
CHECK_LE(c, channels_);
CHECK_GE(height_, 0);
CHECK_LE(h, height_);
CHECK_GE(width_, 0);
CHECK_LE(w, width_);
return ((n * channels_ + c) * height_ + h) * width_ + w;
}
// Copy from source. If copy_diff is false, we copy the data; if copy_diff
Expand Down
14 changes: 11 additions & 3 deletions include/caffe/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_

#include <boost/random/mersenne_twister.hpp>
#include <boost/shared_ptr.hpp>
#include <cublas_v2.h>
#include <cuda.h>
#include <curand.h>
// cuda driver types
#include <driver_types.h>
#include <glog/logging.h>
#include <mkl_vsl.h>
//#include <mkl_vsl.h>

// various checks for different function calls.
#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
Expand Down Expand Up @@ -83,8 +84,13 @@ class Caffe {
inline static curandGenerator_t curand_generator() {
return Get().curand_generator_;
}

// Returns the MKL random stream.
inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
//inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }

typedef boost::mt19937 random_generator_t;
inline static random_generator_t &vsl_stream() { return Get().random_generator_; }

// Returns the mode: running on CPU or GPU.
inline static Brew mode() { return Get().mode_; }
// Returns the phase: TRAIN or TEST.
Expand All @@ -108,7 +114,9 @@ class Caffe {
protected:
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
VSLStreamStatePtr vsl_stream_;
//VSLStreamStatePtr vsl_stream_;
random_generator_t random_generator_;

Brew mode_;
Phase phase_;
static shared_ptr<Caffe> singleton_;
Expand Down
2 changes: 1 addition & 1 deletion include/caffe/filler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#ifndef CAFFE_FILLER_HPP
#define CAFFE_FILLER_HPP

#include <mkl.h>
//#include <mkl.h>
#include <string>

#include "caffe/common.hpp"
Expand Down
12 changes: 10 additions & 2 deletions include/caffe/util/math_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
#define CAFFE_UTIL_MATH_FUNCTIONS_H_

#include <mkl.h>

#include <cublas_v2.h>

#include "caffe/util/mkl_alternate.hpp"

namespace caffe {

// Decaf gemm provides a simpler interface to the gemm functions, with the
Expand Down Expand Up @@ -44,7 +46,7 @@ void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
Dtype* Y);

template <typename Dtype>
void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
void caffe_cpu_axpby(const int N, const Dtype alpha, const Dtype* X,
const Dtype beta, Dtype* Y);

template <typename Dtype>
Expand Down Expand Up @@ -84,13 +86,19 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
template <typename Dtype>
void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);

template <typename Dtype>
Dtype caffe_nextafter(const Dtype b);

template <typename Dtype>
void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);

template <typename Dtype>
void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
const Dtype sigma);

template <typename Dtype>
void caffe_vRngBernoulli(const int n, Dtype* r, const double p);

template <typename Dtype>
void caffe_exp(const int n, const Dtype* a, Dtype* y);

Expand Down
95 changes: 95 additions & 0 deletions include/caffe/util/mkl_alternate.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright 2013 Rowland Depp

#ifndef CAFFE_UTIL_MKL_ALTERNATE_H_
#define CAFFE_UTIL_MKL_ALTERNATE_H_

#ifdef USE_MKL

#include <mkl.h>

#else // If use MKL, simply include the MKL header

#include <cblas.h>
#include <math.h>

// Functions that caffe uses but are not present if MKL is not linked.

// A simple way to define the vsl unary functions. The operation should
// be in the form e.g. y[i] = sqrt(a[i])
#define DEFINE_VSL_UNARY_FUNC(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vs##name( \
const int n, const float* a, float* y) { \
v##name<float>(n, a, y); \
} \
inline void vd##name( \
const int n, const double* a, double* y) { \
v##name<double>(n, a, y); \
}

DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));

// A simple way to define the vsl unary functions with singular parameter b.
// The operation should be in the form e.g. y[i] = pow(a[i], b)
#define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vs##name( \
const int n, const float* a, const float b, float* y) { \
v##name<float>(n, a, b, y); \
} \
inline void vd##name( \
const int n, const double* a, const float b, double* y) { \
v##name<double>(n, a, b, y); \
}

DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = pow(a[i], b));

// A simple way to define the vsl binary functions. The operation should
// be in the form e.g. y[i] = a[i] + b[i]
#define DEFINE_VSL_BINARY_FUNC(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(b); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vs##name( \
const int n, const float* a, const float* b, float* y) { \
v##name<float>(n, a, b, y); \
} \
inline void vd##name( \
const int n, const double* a, const double* b, double* y) { \
v##name<double>(n, a, b, y); \
}

DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i]);
DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i]);
DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i]);
DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]);

// In addition, MKL comes with an additional function axpby that is not present
// in standard blas. We will simply use a two-step (inefficient, of course) way
// to mimic that.
inline void cblas_saxpby(const int N, const float alpha, const float* X,
const int incX, const float beta, float* Y,
const int incY) {
cblas_sscal(N, beta, Y, incY);
cblas_saxpy(N, alpha, X, incX, Y, incY);
}
inline void cblas_daxpby(const int N, const double alpha, const double* X,
const int incX, const double beta, double* Y,
const int incY) {
cblas_dscal(N, beta, Y, incY);
cblas_daxpy(N, alpha, X, incX, Y, incY);
}

#endif // USE_MKL
#endif // CAFFE_UTIL_MKL_ALTERNATE_H_
24 changes: 15 additions & 9 deletions src/caffe/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ long cluster_seedgen(void) {

Caffe::Caffe()
: mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL),
curand_generator_(NULL), vsl_stream_(NULL) {
curand_generator_(NULL),
//vsl_stream_(NULL)
random_generator_()
{
// Try to create a cublas handler, and report an error if failed (but we will
// keep the program running as one might just want to run CPU code).
if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
Expand All @@ -34,21 +37,22 @@ Caffe::Caffe()
!= CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
}

// Try to create a vsl stream. This should almost always work, but we will
// check it anyway.
if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) {
LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
<< "won't be available.";
}
//if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, cluster_seedgen()) != VSL_STATUS_OK) {
// LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
// << "won't be available.";
//}
}

Caffe::~Caffe() {
if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
if (curand_generator_) {
CURAND_CHECK(curandDestroyGenerator(curand_generator_));
}
if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
};
//if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
}

void Caffe::set_random_seed(const unsigned int seed) {
// Curand seed
Expand All @@ -64,8 +68,10 @@ void Caffe::set_random_seed(const unsigned int seed) {
LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
}
// VSL seed
VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
//VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
//VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
Get().random_generator_ = random_generator_t(seed);

}

void Caffe::SetDevice(const int device_id) {
Expand Down
7 changes: 5 additions & 2 deletions src/caffe/layers/dropout_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <limits>

#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/layer.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/vision_layers.hpp"
Expand Down Expand Up @@ -34,8 +35,10 @@ void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const int count = bottom[0]->count();
if (Caffe::phase() == Caffe::TRAIN) {
// Create random numbers
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
count, mask, 1. - threshold_);
//viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
// count, mask, 1. - threshold_);
caffe_vRngBernoulli<int>(count, mask, 1. - threshold_);

for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] * mask[i] * scale_;
}
Expand Down
2 changes: 1 addition & 1 deletion src/caffe/layers/inner_product_layer.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2013 Yangqing Jia


#include <mkl.h>
//#include <mkl.h>
#include <cublas_v2.h>

#include <vector>
Expand Down
2 changes: 1 addition & 1 deletion src/caffe/layers/loss_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ Dtype EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
Dtype loss = caffe_cpu_dot(
count, difference_.cpu_data(), difference_.cpu_data()) / num / Dtype(2);
// Compute the gradient
caffe_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
caffe_cpu_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
(*bottom)[0]->mutable_cpu_diff());
return loss;
}
Expand Down
2 changes: 1 addition & 1 deletion src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
// Compute the value to history, and then copy them to the blob's diff.
Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
caffe_axpby(net_params[param_id]->count(), local_rate,
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
if (local_decay) {
Expand Down
17 changes: 11 additions & 6 deletions src/caffe/test/test_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"

#include "caffe/util/math_functions.hpp"
#include "caffe/test/test_caffe_main.hpp"

namespace caffe {
Expand All @@ -20,7 +20,8 @@ TEST_F(CommonTest, TestCublasHandler) {
}

TEST_F(CommonTest, TestVslStream) {
EXPECT_TRUE(Caffe::vsl_stream());
//EXPECT_TRUE(Caffe::vsl_stream());
EXPECT_TRUE(true);
}

TEST_F(CommonTest, TestBrewMode) {
Expand All @@ -39,11 +40,15 @@ TEST_F(CommonTest, TestRandSeedCPU) {
SyncedMemory data_a(10 * sizeof(int));
SyncedMemory data_b(10 * sizeof(int));
Caffe::set_random_seed(1701);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
10, (int*)data_a.mutable_cpu_data(), 0.5);
//viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
// 10, (int*)data_a.mutable_cpu_data(), 0.5);
caffe_vRngBernoulli(10, (int*)data_a.mutable_cpu_data(), 0.5);

Caffe::set_random_seed(1701);
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
10, (int*)data_b.mutable_cpu_data(), 0.5);
//viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffe::vsl_stream(),
// 10, (int*)data_b.mutable_cpu_data(), 0.5);
caffe_vRngBernoulli(10, (int*)data_b.mutable_cpu_data(), 0.5);

for (int i = 0; i < 10; ++i) {
EXPECT_EQ(((const int*)(data_a.cpu_data()))[i],
((const int*)(data_b.cpu_data()))[i]);
Expand Down
3 changes: 3 additions & 0 deletions src/caffe/test/test_flatten_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class FlattenLayerTest : public ::testing::Test {
FlattenLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
blob_top_(new Blob<Dtype>()) {
Caffe::set_random_seed(1701);
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
Expand Down Expand Up @@ -72,6 +73,8 @@ TYPED_TEST(FlattenLayerTest, TestGPU) {
for (int c = 0; c < 3 * 6 * 5; ++c) {
EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0),
this->blob_bottom_->data_at(0, c / (6 * 5), (c / 5) % 6, c % 5));
EXPECT_EQ(this->blob_top_->data_at(1, c, 0, 0),
this->blob_bottom_->data_at(1, c / (6 * 5), (c / 5) % 6, c % 5));
}
}

Expand Down
Loading