Merge pull request BVLC#175 from drnikolaev/caffe-0.15-alloc-dealloc
Fixed issue with redundant memory allocations/deallocations. @pooyadavoodi thank you for reviewing this!
drnikolaev authored Jun 21, 2016
2 parents 5fb4c72 + 3e31492 commit 92e5c6d
Showing 6 changed files with 107 additions and 73 deletions.
2 changes: 1 addition & 1 deletion include/caffe/layer.hpp
@@ -341,7 +341,7 @@ class Layer {
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// LOG(WARNING) << "Using CPU code as backup.";
return Forward_cpu(bottom, top);
Forward_cpu(bottom, top);
}

/**
36 changes: 35 additions & 1 deletion include/caffe/layers/cudnn_conv_layer.hpp
@@ -31,6 +31,23 @@ namespace caffe {
*/
template <typename Dtype>
class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
// In iteration 0, use a small amount of memory in order to leave
// most of memory for allocating layer blobs.
// NOLINT_NEXT_LINE(build/storage_class)
const static size_t INITIAL_WORKSPACE_SIZE;
// Use 95% of available memory.
// Using all of memory may result in failure of workspace.reserve.
// NOLINT_NEXT_LINE(build/storage_class)
const static float MAX_WORKSPACE_RATIO;
// We update it on second Fwd/Bwd pass and we allocate it *once*
// when we start third pass. We might recompute it later if demand grows
// and/or we suddenly need to get extra memory for other needs.
static size_t WORKSPACE_SIZE;
// This is the workspace used by all Convolution layers one after another.
// We carry it global to prevent unnecessary allocations/deallocations
// because they hurt performance.
static GPUMemory::Workspace WORKSPACE;

public:
explicit CuDNNConvolutionLayer(const LayerParameter& param)
: ConvolutionLayer<Dtype>(param), handles_setup_(false),
@@ -64,7 +81,6 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
size_t *workspace_fwd_sizes_;
size_t *workspace_bwd_data_sizes_;
size_t *workspace_bwd_filter_sizes_;
GPUMemory::Workspace workspace;

private:
bool use_algo_seeker_;
@@ -85,7 +101,25 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {

bool use_reshape_;
bool initialized_cached_descs_;

void UpdateWorkspaceDemand(int size);

// This is current *demand*: it might be not yet allocated.
};

template<typename Dtype>
size_t CuDNNConvolutionLayer<Dtype>::WORKSPACE_SIZE = 0UL;

template<typename Dtype>
const size_t CuDNNConvolutionLayer<Dtype>::INITIAL_WORKSPACE_SIZE =
4*1024*1024;

template<typename Dtype>
GPUMemory::Workspace CuDNNConvolutionLayer<Dtype>::WORKSPACE;

template<typename Dtype>
const float CuDNNConvolutionLayer<Dtype>::MAX_WORKSPACE_RATIO = 0.95F;

#endif

} // namespace caffe
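The static members added above replace the per-layer GPUMemory::Workspace member: all CuDNNConvolutionLayer instances now share a single buffer plus one running size demand, so the workspace is grown at most a few times per run instead of being reserved and released inside every Forward/Backward call. The standalone sketch below illustrates that pattern with plain host memory; Workspace, ConvLayer and the byte counts are illustrative stand-ins, not the Caffe types.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Stand-in for GPUMemory::Workspace: one buffer that only grows on demand
// and is released explicitly (the real class wraps the CUDA allocator).
struct Workspace {
  void*  ptr  = nullptr;
  size_t size = 0;
  bool try_reserve(size_t bytes) {
    if (bytes <= size) return true;          // already large enough: no-op
    void* p = std::realloc(ptr, bytes);
    if (p == nullptr) return false;
    ptr = p;
    size = bytes;
    return true;
  }
  void release() { std::free(ptr); ptr = nullptr; size = 0; }
};

// Every layer shares one workspace and one demand counter, mirroring the
// static WORKSPACE / WORKSPACE_SIZE members introduced by this commit.
struct ConvLayer {
  static Workspace shared_ws;
  static size_t    demand;

  size_t fwd_bytes = 0, bwd_bytes = 0;       // per-layer requirements

  void update_demand() {                     // cf. UpdateWorkspaceDemand()
    demand = std::max(demand, std::max(fwd_bytes, bwd_bytes));
  }
  void forward() {
    // Grows the shared buffer only when the demand exceeds its current
    // size, so steady-state iterations perform no allocation at all.
    if (!shared_ws.try_reserve(demand)) {
      std::fprintf(stderr, "Out of memory: failed to allocate %zu bytes\n", demand);
      std::abort();
    }
    // ... launch kernels that use shared_ws.ptr / shared_ws.size ...
  }
};

Workspace ConvLayer::shared_ws;
size_t    ConvLayer::demand = 0;

int main() {
  ConvLayer a, b;
  a.fwd_bytes = 8u << 20;                    // 8 MiB
  b.fwd_bytes = 32u << 20;                   // 32 MiB
  a.update_demand();
  b.update_demand();
  a.forward();
  b.forward();                               // one 32 MiB buffer serves both layers
  ConvLayer::shared_ws.release();
  return 0;
}

Because the buffer is static, the real code releases it only in LayerSetUp (when a net is rebuilt) and in the destructor, as the .cpp changes below show.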
10 changes: 8 additions & 2 deletions include/caffe/util/gpu_memory.hpp
@@ -19,7 +19,9 @@ struct GPUMemory {
template <class Any>
static void allocate(Any** ptr, size_t size,
cudaStream_t stream = cudaStreamDefault) {
CHECK(try_allocate(reinterpret_cast<void**>(ptr), size, stream));
if (!try_allocate(reinterpret_cast<void**>(ptr), size, stream)) {
LOG(FATAL) << "Out of memory: failed to allocate " << size << " bytes";
}
}

static void deallocate(void* ptr,
@@ -74,7 +76,11 @@ struct GPUMemory {
return status;
}

void reserve(size_t size) { CHECK(try_reserve(size)); }
void reserve(size_t size) {
if (!try_reserve(size)) {
LOG(FATAL) << "Out of memory: failed to allocate " << size << " bytes";
}
}

void release() {
if (mgr_.using_pool()) {
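The two changes above swap a bare CHECK(try_allocate(...)) / CHECK(try_reserve(...)) for an explicit failure branch, so the fatal log carries the requested byte count instead of only the failed expression. A minimal sketch of the same try/fatal split, using malloc and stderr in place of the CUDA allocator and glog:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// try_allocate reports failure instead of aborting, so callers can retry
// with a smaller request; allocate() is the fatal convenience wrapper.
bool try_allocate(void** ptr, size_t size) {
  *ptr = std::malloc(size);                  // cudaMalloc in the real helper
  return *ptr != nullptr;
}

void allocate(void** ptr, size_t size) {
  if (!try_allocate(ptr, size)) {
    // The explicit message includes the request size, which a plain
    // CHECK(try_allocate(...)) would not print.
    std::fprintf(stderr, "Out of memory: failed to allocate %zu bytes\n", size);
    std::abort();                            // LOG(FATAL) likewise aborts
  }
}

int main() {
  void* p = nullptr;
  allocate(&p, 1024);
  std::free(p);
  return 0;
}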
51 changes: 32 additions & 19 deletions src/caffe/layers/cudnn_conv_layer.cpp
@@ -90,6 +90,8 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
use_reshape_ = true;
// When true, cached bottom and conv descriptors need to be set.
initialized_cached_descs_ = false;
// In case of reusing it
WORKSPACE.release();
}

template <typename Dtype>
@@ -179,17 +181,15 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
if (use_modest_workspace_) {
// In iteration 0, use a small amount of memory in order to leave
// most of memory for allocating layer blobs.
// TODO: Read 8*1024*1024 from a data member variable.
workspace_bytes = 8*1024*1024;
workspace_bytes = INITIAL_WORKSPACE_SIZE;
} else {
// Use 90% of available memory.
// Use 95% of available memory.
// Using all of memory may result in failure of workspace.reserve.
// TODO: Since 90% of memory might be too large, we can allocate
// TODO: Since 95% of memory might be too large, we can allocate
// exactly how much FindEx needs by taking the maximum
// workspace among all algorithms (requires an initial call
// to FindEx with workspace size 0).
// TODO: Read 0.9 from a data member variable.
workspace_bytes = workspace_limit_bytes * 0.9;
workspace_bytes = workspace_limit_bytes * MAX_WORKSPACE_RATIO;
// Avoid seeking for an algorithm in subsequent iterations
use_algo_seeker_ = false;
}
@@ -233,8 +233,9 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
Caffe::cudnn_handle(),
filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
bwd_data_algo_[i], &workspace_bwd_data_sizes_[i]) );
bwd_data_algo_[i], &workspace_bwd_data_sizes_[i]));
}
UpdateWorkspaceDemand(bottom.size()); // update WORKSPACE_SIZE

// Tensor descriptor for bias.
if (this->bias_term_) {
@@ -292,11 +293,7 @@ void CuDNNConvolutionLayer<Dtype>::FindExConvAlgo(
void *tmp_weights;
const int tmp_weights_size = sizeof(Dtype) * weight_offset_;
GPUMemory::allocate(&tmp_weights, tmp_weights_size);

// TODO: Try reducing workspace_bytes if it fails.
// In case, workspace_bytes is 90% of available memory,
// reduce it to 75%; if it fails again, reduce it to 50% and so on.
workspace.reserve(workspace_bytes);
WORKSPACE.reserve(workspace_bytes);

for (int i = 0; i < bottom.size(); i++) {
// Find forward algorithm
@@ -312,8 +309,8 @@ void CuDNNConvolutionLayer<Dtype>::FindExConvAlgo(
kRequestAlgoCount,
&fwd_algo_count,
fwd_results,
workspace.data(),
workspace.size()));
WORKSPACE.data(),
WORKSPACE.size()));
fwd_algo_[i] = fwd_results[0].algo;
workspace_fwd_sizes_[i] = fwd_results[0].memory;

@@ -332,8 +329,8 @@ void CuDNNConvolutionLayer<Dtype>::FindExConvAlgo(
kRequestAlgoCount,
&filter_algo_count,
bwd_filter_results,
workspace.data(),
workspace.size()));
WORKSPACE.data(),
WORKSPACE.size()));
bwd_filter_algo_[i] = bwd_filter_results[0].algo;
workspace_bwd_filter_sizes_[i] = bwd_filter_results[0].memory;

@@ -350,15 +347,14 @@ void CuDNNConvolutionLayer<Dtype>::FindExConvAlgo(
kRequestAlgoCount,
&data_algo_count,
bwd_data_results,
workspace.data(),
workspace.size()));
WORKSPACE.data(),
WORKSPACE.size()));

bwd_data_algo_[i] = bwd_data_results[0].algo;
workspace_bwd_data_sizes_[i] = bwd_data_results[0].memory;
}
}
GPUMemory::deallocate(tmp_weights);
workspace.release();
}
#endif

@@ -453,8 +449,25 @@ bool CuDNNConvolutionLayer<Dtype>::IsConvDescChanged(
return false;
}

template <typename Dtype>
void CuDNNConvolutionLayer<Dtype>::UpdateWorkspaceDemand(int size) {
// Updating the maximum WORKSPACE_SIZE
for (int i = 0; i < size; ++i) {
if (workspace_fwd_sizes_[i] > WORKSPACE_SIZE) {
WORKSPACE_SIZE = workspace_fwd_sizes_[i];
}
if (workspace_bwd_filter_sizes_[i] > WORKSPACE_SIZE) {
WORKSPACE_SIZE = workspace_bwd_filter_sizes_[i];
}
if (workspace_bwd_data_sizes_[i] > WORKSPACE_SIZE) {
WORKSPACE_SIZE = workspace_bwd_data_sizes_[i];
}
}
}

template <typename Dtype>
CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
WORKSPACE.release();
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }

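UpdateWorkspaceDemand() above folds every size reported by cuDNN into one running maximum; a single number is enough because the convolution layers run sequentially and reuse the same static buffer. A small self-contained sketch of that reduction, with plain vectors standing in for the workspace_*_sizes_ arrays:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Running maximum over the forward, backward-filter and backward-data
// workspace sizes of every bottom blob, mirroring UpdateWorkspaceDemand().
size_t update_demand(size_t current,
                     const std::vector<size_t>& fwd,
                     const std::vector<size_t>& bwd_filter,
                     const std::vector<size_t>& bwd_data) {
  for (size_t i = 0; i < fwd.size(); ++i) {
    current = std::max({current, fwd[i], bwd_filter[i], bwd_data[i]});
  }
  return current;
}

int main() {
  size_t demand = 0;
  demand = update_demand(demand,
                         {1 << 20, 4 << 20},    // forward
                         {2 << 20, 8 << 20},    // backward filter
                         {3 << 20, 6 << 20});   // backward data
  std::cout << "workspace demand: " << demand << " bytes\n";  // prints 8388608
  return 0;
}

Note that the demand only grows; the shared buffer itself is reserved lazily on the next Forward/Backward pass rather than here.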
44 changes: 13 additions & 31 deletions src/caffe/layers/cudnn_conv_layer.cu
@@ -20,20 +20,14 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
const Dtype* bottom_data = bottom[i]->gpu_data();
Dtype* top_data = top[i]->mutable_gpu_data();

// Test free space and force reshape if allocations have changed
size_t workspace_limit_bytes, total_memory;
GPUMemory::GetInfo(&workspace_limit_bytes, &total_memory);
if (workspace_fwd_sizes_[i] > workspace_limit_bytes) {
use_algo_seeker_ = true;
this->Reshape(bottom, top);
}
// Sometimes closer to zero we might have memory info diverged from reality
// If try_reserve fails, it updates the info internally and we proceed with
// Reshape one more time
if (!workspace.try_reserve(workspace_fwd_sizes_[i])) {
// Note: if WORKSPACE_SIZE is already allocated next line does nothing.
if (!WORKSPACE.try_reserve(WORKSPACE_SIZE)) {
use_algo_seeker_ = true;
this->Reshape(bottom, top);
workspace.reserve(workspace_fwd_sizes_[i]);
WORKSPACE.reserve(WORKSPACE_SIZE);
}

// Forward through cuDNN in parallel over groups.
bottom_descs_[i], bottom_data + bottom_offset_ * g,
filter_desc_, weight + this->weight_offset_ * g,
conv_descs_[i],
fwd_algo_[i], workspace.data(), workspace.size(),
fwd_algo_[i], WORKSPACE.data(), WORKSPACE.size(),
cudnn::dataType<Dtype>::zero,
top_descs_[i], top_data + top_offset_ * g));

}
}

workspace.release();
// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy));
}
// Possibly use faster algorithms by allowing larger workspace.
use_modest_workspace_ = false;
}

template<typename Dtype>
@@ -84,25 +75,15 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
}
for (int i = 0; i < top.size(); ++i) {
const Dtype* top_diff = top[i]->gpu_diff();
// Test free space and force reshape if allocations have changed
size_t workspace_limit_bytes, total_memory;
GPUMemory::GetInfo(&workspace_limit_bytes, &total_memory);
if (workspace_bwd_filter_sizes_[i] > workspace_limit_bytes ||
workspace_bwd_data_sizes_[i] > workspace_limit_bytes) {
use_algo_seeker_ = true;
this->Reshape(bottom, top);
}
// To remove pressure on allocator, allocate the larger of the
// workspaces needed for the following steps

// Sometimes closer to zero we might have memory info diverged from reality
// If try_reserve fails, it updates the info internally and we proceed with
// Reshape one more time
if (!workspace.try_reserve(std::max(workspace_bwd_filter_sizes_[i],
workspace_bwd_data_sizes_[i]))) {
// Reshape one more time.
// Note: if WORKSPACE_SIZE is already allocated next line does nothing.
if (!WORKSPACE.try_reserve(WORKSPACE_SIZE)) {
use_algo_seeker_ = true;
this->Reshape(bottom, top);
workspace.reserve(std::max(workspace_bwd_filter_sizes_[i],
workspace_bwd_data_sizes_[i]));
WORKSPACE.reserve(WORKSPACE_SIZE);
}

// Backward through cuDNN in parallel over groups and gradients.
bottom_descs_[i], bottom_data + bottom_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
bwd_filter_algo_[i], workspace.data(), workspace.size(),
bwd_filter_algo_[i], WORKSPACE.data(), WORKSPACE.size(),
cudnn::dataType<Dtype>::one,
filter_desc_, weight_diff + this->weight_offset_ * g));
}
filter_desc_, weight + this->weight_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
bwd_data_algo_[i], workspace.data(), workspace.size(),
bwd_data_algo_[i], WORKSPACE.data(), WORKSPACE.size(),
cudnn::dataType<Dtype>::zero,
bottom_descs_[i], bottom_diff + bottom_offset_ * g));
}
}

workspace.release();
// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy));
}
// Possibly use faster algorithms by allowing larger workspace.
use_modest_workspace_ = false;
}

INSTANTIATE_LAYER_GPU_FUNCS(CuDNNConvolutionLayer);
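Forward_gpu and Backward_gpu now share one fallback path: try to reserve the static workspace at the recorded demand (a no-op once the buffer is allocated), and only if that fails re-run Reshape with the algorithm seeker enabled and reserve again, this time fatally. The per-iteration GPUMemory::GetInfo query and the workspace.release() at the end of every pass are gone, which is exactly the redundant allocate/deallocate cycle this commit removes. The sketch below shows only that control flow; try_reserve, reserve and reshape are hypothetical stand-ins for the Workspace and layer methods.

#include <cstddef>
#include <cstdio>
#include <cstdlib>

static bool memory_info_stale = true;        // simulates drifted free-memory info

bool try_reserve(size_t bytes) {
  (void)bytes;
  return !memory_info_stale;                 // fails while the info is stale
}

void reshape() { memory_info_stale = false; }  // re-runs the algorithm seeker

void reserve(size_t bytes) {
  if (!try_reserve(bytes)) {
    std::fprintf(stderr, "Out of memory: failed to allocate %zu bytes\n", bytes);
    std::abort();
  }
}

// Mirrors the pattern in Forward_gpu/Backward_gpu: a cheap try first, and
// only on failure a full Reshape followed by a hard reserve of the demand.
void ensure_workspace(size_t demand) {
  if (!try_reserve(demand)) {
    reshape();
    reserve(demand);
  }
}

int main() {
  ensure_workspace(64u << 20);               // 64 MiB demand
  std::puts("workspace ready");
  return 0;
}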
37 changes: 18 additions & 19 deletions tools/caffe.cpp
@@ -357,6 +357,7 @@ int time() {
// Do a number of clean forward and backward pass,
// so that memory allocation are done,
// and future iterations will be more stable.
Timer init_timer;
Timer forward_timer;
Timer backward_timer;
double forward_time = 0.0;
@@ -365,30 +366,28 @@ int time() {
LOG(INFO) << "Initialization for " << kInitIterations << " iterations.";
// Note that for the speed benchmark, we will assume that the network does
// not take any input blobs.
LOG(INFO) << "Performing Forward";
float initial_loss;
forward_timer.Start();
for (int j = 0; j < kInitIterations; ++j) {
caffe_net.Forward(&initial_loss);
}
forward_time += forward_timer.MicroSeconds();
LOG(INFO) << "Initial loss: " << initial_loss;
LOG(INFO) << "Performing Backward";
backward_timer.Start();
for (int j = 0; j < kInitIterations; ++j) {
caffe_net.Backward();
}
backward_time += backward_timer.MicroSeconds();
LOG(INFO) << "Average Initialization Forward pass: " << forward_time /
1000 / kInitIterations << " ms.";
LOG(INFO) << "Average Initialization Backward pass: " << backward_time /
1000 / kInitIterations << " ms.";

LOG(INFO) << "Performing initial Forward/Backward";
const vector<shared_ptr<Layer<float> > >& layers = caffe_net.layers();
const vector<vector<Blob<float>*> >& bottom_vecs = caffe_net.bottom_vecs();
const vector<vector<Blob<float>*> >& top_vecs = caffe_net.top_vecs();
const vector<vector<bool> >& bottom_need_backward =
caffe_net.bottom_need_backward();
float initial_loss = 0.F;
init_timer.Start();
for (int j = 0; j < kInitIterations; ++j) {
for (int i = 0; i < layers.size(); ++i) {
initial_loss += layers[i]->Forward(bottom_vecs[i], top_vecs[i]);
}
for (int i = layers.size() - 1; i >= 0; --i) {
layers[i]->Backward(top_vecs[i], bottom_need_backward[i],
bottom_vecs[i]);
}
}
double init_time = init_timer.MilliSeconds();
LOG(INFO) << "Initial Forward/Backward complete, loss: " << initial_loss;
LOG(INFO) << "Average Initialization Forward/Backward pass: " << init_time /
kInitIterations << " ms.";

LOG(INFO) << "*** Benchmark begins ***";
LOG(INFO) << "Testing for " << FLAGS_iterations << " iterations.";
Timer total_timer;
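The warm-up in tools/caffe.cpp now runs kInitIterations of combined per-layer Forward/Backward passes under a single timer, instead of timing whole-net Forward and Backward separately, so the reported average matches the per-layer loop used in the benchmark that follows. A standalone sketch of that measurement using std::chrono; layer_forward and layer_backward are stubs for the layers[i]->Forward(...) and layers[i]->Backward(...) calls.

#include <chrono>
#include <iostream>

// Stubs standing in for the per-layer Forward/Backward calls.
float layer_forward(int /*i*/)  { return 0.1f; }   // returns the layer's loss term
void  layer_backward(int /*i*/) {}

int main() {
  const int kInitIterations = 5;
  const int num_layers = 10;
  float initial_loss = 0.f;

  auto start = std::chrono::steady_clock::now();
  for (int j = 0; j < kInitIterations; ++j) {
    for (int i = 0; i < num_layers; ++i)      initial_loss += layer_forward(i);
    for (int i = num_layers - 1; i >= 0; --i) layer_backward(i);
  }
  double init_ms = std::chrono::duration<double, std::milli>(
      std::chrono::steady_clock::now() - start).count();

  std::cout << "Initial Forward/Backward complete, loss: " << initial_loss << "\n"
            << "Average Initialization Forward/Backward pass: "
            << init_ms / kInitIterations << " ms.\n";
  return 0;
}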
