Allow dynamic batch sizes in all the layers #195

Closed
wants to merge 13 commits
30 changes: 25 additions & 5 deletions include/caffe/blob.hpp
@@ -13,18 +13,23 @@ template <typename Dtype>
class Blob {
public:
Blob()
: num_(0), channels_(0), height_(0), width_(0), count_(0), data_(),
diff_() {}
: data_(), diff_(), num_(0), channels_(0), height_(0), width_(0),
count_(0), capacity_(0) {}
explicit Blob(const int num, const int channels, const int height,
const int width);
virtual ~Blob() {}
void Reshape(const int num, const int height,
const int width, const int channels);
void Reshape(const int num, const int channels, const int height,
const int width);

Review comment (Member): Thank you for making this consistent with the internal indexing by N x K x H x W.

// Only reshape the num while keeping the other three dims intact
void ReshapeNum(const int num);
// Re-allocate memory if the current blob capacity is not big enough
void Reserve(const size_t capacity);
inline int num() const { return num_; }
inline int channels() const { return channels_; }
inline int height() const { return height_; }
inline int width() const { return width_; }
inline int count() const {return count_; }
inline int count() const { return count_; }
inline size_t capacity() const { return capacity_; }
inline int offset(const int n, const int c = 0, const int h = 0,
const int w = 0) const {
return ((n * channels_ + c) * height_ + h) * width_ + w;
@@ -52,6 +57,20 @@ class Blob {
Dtype* mutable_gpu_data();
Dtype* mutable_cpu_diff();
Dtype* mutable_gpu_diff();
inline size_t data_size() const { return data_->size(); }
inline size_t diff_size() const { return diff_->size(); }
inline bool has_data() const {
if (data_) {
return true;
}
return false;
}
inline bool has_diff() const {
if (diff_) {
return true;
}
return false;
}
void Update();
void FromProto(const BlobProto& proto);
void ToProto(BlobProto* proto, bool write_diff = false) const;
@@ -64,6 +83,7 @@
int height_;
int width_;
int count_;
size_t capacity_;

DISABLE_COPY_AND_ASSIGN(Blob);
}; // class Blob
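
Aside (not part of the diff): the offset() accessor above linearizes the N x K x H x W indexing that the review comment refers to. A minimal self-contained sketch with hypothetical dimensions, using a free function as a stand-in for the member:

#include <cassert>

// Row-major N x K x H x W linearization, mirroring Blob::offset().
inline int offset(int n, int c, int h, int w,
                  int channels, int height, int width) {
  return ((n * channels + c) * height + h) * width + w;
}

int main() {
  // Shape 2 x 3 x 4 x 5: image n=1 starts at 3*4*5 = 60 and
  // channel c=2 adds 2*4*5 = 40, so offset(1, 2, 0, 0) = 100.
  assert(offset(1, 2, 0, 0, 3, 4, 5) == 100);
  return 0;
}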
4 changes: 2 additions & 2 deletions include/caffe/syncedmem.hpp
@@ -45,13 +45,13 @@ class SyncedMemory {
void* mutable_gpu_data();
enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED };
SyncedHead head() { return head_; }
size_t size() { return size_; }
size_t size() const { return size_; }
private:
void to_cpu();
void to_gpu();
void* cpu_ptr_;
void* gpu_ptr_;
size_t size_;
const size_t size_;
SyncedHead head_;

DISABLE_COPY_AND_ASSIGN(SyncedMemory);
34 changes: 25 additions & 9 deletions src/caffe/blob.cpp
@@ -2,13 +2,21 @@

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <algorithm> // std::max

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {
using std::max;

template <typename Dtype>
Blob<Dtype>::Blob(const int num, const int channels, const int height,
const int width) {
Reshape(num, channels, height, width);
}

template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
@@ -22,21 +30,29 @@ void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
height_ = height;
width_ = width;
count_ = num_ * channels_ * height_ * width_;
if (count_) {
data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
Reserve(count_);
}

template <typename Dtype>
void Blob<Dtype>::ReshapeNum(const int num) {
Reshape(num, channels_, height_, width_);
}

template <typename Dtype>
void Blob<Dtype>::Reserve(const size_t capacity) {
if (capacity) {
if (capacity_ < capacity) {
capacity_ = capacity;
data_.reset(new SyncedMemory(capacity * sizeof(Dtype)));
diff_.reset(new SyncedMemory(capacity * sizeof(Dtype)));
}
} else {
capacity_ = 0;
data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
}
}

template <typename Dtype>
Blob<Dtype>::Blob(const int num, const int channels, const int height,
const int width) {
Reshape(num, channels, height, width);
}

template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
CHECK(data_);
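
Aside (not part of the diff): Reserve() gives Blob a grow-only capacity, so reshaping to a smaller batch and back never reallocates; only exceeding the reserved capacity does. A minimal self-contained stand-in for these semantics (ToyBlob and its std::vector storage are illustrative only; the real Blob uses SyncedMemory):

#include <cstddef>
#include <vector>

// Toy model of the capacity semantics added in this patch: Reshape() updates
// the logical count and reserves, Reserve() reallocates only when the new
// count exceeds the current capacity, and ReshapeNum() changes only the
// batch dimension.
class ToyBlob {
 public:
  void Reshape(int num, int channels, int height, int width) {
    num_ = num; channels_ = channels; height_ = height; width_ = width;
    count_ = static_cast<size_t>(num) * channels * height * width;
    Reserve(count_);
  }
  void ReshapeNum(int num) { Reshape(num, channels_, height_, width_); }
  void Reserve(size_t capacity) {
    if (capacity > capacity_) {   // grow-only: shrinking keeps the buffer
      capacity_ = capacity;
      data_.resize(capacity);     // the only place that reallocates
    }
  }
  size_t count() const { return count_; }
  size_t capacity() const { return capacity_; }

 private:
  int num_ = 0, channels_ = 0, height_ = 0, width_ = 0;
  size_t count_ = 0, capacity_ = 0;
  std::vector<float> data_;
};

With this model, Reshape(64, 3, 227, 227) followed by ReshapeNum(16) and then ReshapeNum(64) performs a single allocation, which is the behaviour the per-layer changes below rely on.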
7 changes: 4 additions & 3 deletions src/caffe/layers/bnll_layer.cpp
@@ -1,12 +1,13 @@
// Copyright 2013 Yangqing Jia

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include <algorithm>
#include <vector>

using std::min;
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {
using std::min;

const float kBNLL_THRESHOLD = 50.;

23 changes: 12 additions & 11 deletions src/caffe/layers/concat_layer.cpp
@@ -38,7 +38,7 @@ void ConcatLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
}
(*top)[0]->Reshape(NUM_, CHANNELS_, HEIGHT_, WIDTH_);
CHECK_EQ(COUNT_, (*top)[0]->count());
CHECK_LE(COUNT_, (*top)[0]->count());
}

template <typename Dtype>
@@ -50,18 +50,19 @@ void ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->cpu_data();
int num_elem = bottom[i]->count();
caffe_copy(num_elem, bottom_data, top_data+(*top)[0]->offset(offset_num));
caffe_copy(num_elem, bottom_data,
top_data + (*top)[0]->offset(offset_num));
offset_num += bottom[i]->num();
}
} else if (concat_dim_ == 1) {
int offset_channel = 0;
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->cpu_data();
int num_elem =
bottom[i]->channels()*bottom[i]->height()*bottom[i]->width();
for (int n = 0; n < NUM_; ++n) {
caffe_copy(num_elem, bottom_data+bottom[i]->offset(n),
top_data+(*top)[0]->offset(n, offset_channel));
bottom[i]->channels() * bottom[i]->height() * bottom[i]->width();
for (int n = 0; n < bottom[i]->num(); ++n) {
caffe_copy(num_elem, bottom_data + bottom[i]->offset(n),
top_data + (*top)[0]->offset(n, offset_channel));
}
offset_channel += bottom[i]->channels();
}
@@ -81,18 +82,18 @@ Dtype ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
Blob<Dtype>* blob = (*bottom)[i];
Dtype* bottom_diff = blob->mutable_cpu_diff();
caffe_copy(blob->count(),
top_diff+top[0]->offset(offset_num), bottom_diff);
top_diff + top[0]->offset(offset_num), bottom_diff);
offset_num += blob->num();
}
} else if (concat_dim_ == 1) {
int offset_channel = 0;
for (int i = 0; i < bottom->size(); ++i) {
Blob<Dtype>* blob = (*bottom)[i];
Dtype* bottom_diff = blob->mutable_cpu_diff();
int num_elem = blob->channels()*blob->height()*blob->width();
for (int n = 0; n < NUM_; ++n) {
caffe_copy(num_elem, top_diff+top[0]->offset(n, offset_channel),
bottom_diff+blob->offset(n));
int num_elem = blob->channels() * blob->height() * blob->width();
for (int n = 0; n < top[i]->num(); ++n) {
caffe_copy(num_elem, top_diff + top[0]->offset(n, offset_channel),
bottom_diff + blob->offset(n));
}
offset_channel += blob->channels();
}
16 changes: 8 additions & 8 deletions src/caffe/layers/concat_layer.cu
@@ -25,10 +25,10 @@ void ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->gpu_data();
int num_elem =
bottom[i]->channels()*bottom[i]->height()*bottom[i]->width();
for (int n = 0; n < NUM_; ++n) {
caffe_gpu_copy(num_elem, bottom_data+bottom[i]->offset(n),
top_data+(*top)[0]->offset(n, offset_channel));
bottom[i]->channels() * bottom[i]->height() * bottom[i]->width();
for (int n = 0; n < bottom[i]->num(); ++n) {
caffe_gpu_copy(num_elem, bottom_data + bottom[i]->offset(n),
top_data + (*top)[0]->offset(n, offset_channel));
}
offset_channel += bottom[i]->channels();
}
@@ -56,10 +56,10 @@ Dtype ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
for (int i = 0; i < bottom->size(); ++i) {
Blob<Dtype>* blob = (*bottom)[i];
Dtype* bottom_diff = blob->mutable_gpu_diff();
int num_elem = blob->channels()*blob->height()*blob->width();
for (int n = 0; n < NUM_; ++n) {
caffe_gpu_copy(num_elem, top_diff+top[0]->offset(n, offset_channel),
bottom_diff+blob->offset(n));
int num_elem = blob->channels() * blob->height() * blob->width();
for (int n = 0; n < top[i]->num(); ++n) {
caffe_gpu_copy(num_elem, top_diff + top[0]->offset(n, offset_channel),
bottom_diff + blob->offset(n));
}
offset_channel += blob->channels();
}
6 changes: 3 additions & 3 deletions src/caffe/layers/conv_layer.cpp
@@ -86,7 +86,7 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
for (int n = 0; n < NUM_; ++n) {
for (int n = 0; n < bottom[0]->num(); ++n) {
// First, im2col
im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
@@ -122,7 +122,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
if (biasterm_) {
bias_diff = this->blobs_[1]->mutable_cpu_diff();
memset(bias_diff, 0, sizeof(Dtype) * this->blobs_[1]->count());
for (int n = 0; n < NUM_; ++n) {
for (int n = 0; n < top[0]->num(); ++n) {
caffe_cpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
1., top_diff + top[0]->offset(n),
reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()), 1.,
@@ -134,7 +134,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
int col_offset = K_ * N_;
int top_offset = M_ * N_;
memset(weight_diff, 0, sizeof(Dtype) * this->blobs_[0]->count());
for (int n = 0; n < NUM_; ++n) {
for (int n = 0; n < top[0]->num(); ++n) {
// since we saved memory in the forward pass by not storing all col data,
// we will need to recompute them.
im2col_cpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
4 changes: 2 additions & 2 deletions src/caffe/layers/conv_layer.cu
@@ -20,7 +20,7 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
for (int n = 0; n < NUM_; ++n) {
for (int n = 0; n < bottom[0]->num(); ++n) {
// First, im2col
im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
@@ -70,7 +70,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
int top_offset = M_ * N_;
CUDA_CHECK(cudaMemset(weight_diff, 0,
sizeof(Dtype) * this->blobs_[0]->count()));
for (int n = 0; n < NUM_; ++n) {
for (int n = 0; n < top[0]->num(); ++n) {
// since we saved memory in the forward pass by not storing all col data,
// we will need to recompute them.
im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
3 changes: 3 additions & 0 deletions src/caffe/layers/dropout_layer.cpp
@@ -1,11 +1,14 @@
// Copyright 2013 Yangqing Jia

#include <vector>

#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {
using std::vector;

template <typename Dtype>
void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
2 changes: 1 addition & 1 deletion src/caffe/layers/flatten_layer.cpp
@@ -18,7 +18,7 @@ void FlattenLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
(*top)[0]->Reshape(bottom[0]->num(), channels_out, 1, 1);
count_ = bottom[0]->num() * channels_out;
CHECK_EQ(count_, bottom[0]->count());
CHECK_EQ(count_, (*top)[0]->count());
CHECK_LE(count_, (*top)[0]->count());
}

template <typename Dtype>
4 changes: 2 additions & 2 deletions src/caffe/layers/flatten_layer.cu
@@ -13,15 +13,15 @@ void FlattenLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
caffe_gpu_copy(count_, bottom_data, top_data);
caffe_gpu_copy(bottom[0]->count(), bottom_data, top_data);
}

template <typename Dtype>
Dtype FlattenLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
caffe_gpu_copy(count_, top_diff, bottom_diff);
caffe_gpu_copy(top[0]->count(), top_diff, bottom_diff);
return Dtype(0.);
}

25 changes: 14 additions & 11 deletions src/caffe/layers/inner_product_layer.cpp
@@ -66,10 +66,11 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
const Dtype* weight = this->blobs_[0]->cpu_data();
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, bottom[0]->num(), N_, K_,
(Dtype)1., bottom_data, weight, (Dtype)0., top_data);
if (biasterm_) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
caffe_cpu_gemm<Dtype>(
CblasNoTrans, CblasNoTrans, bottom[0]->num(), N_, 1, (Dtype)1.,
reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()),
this->blobs_[1]->cpu_data(), (Dtype)1., top_data);
}
@@ -82,19 +83,21 @@ Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
// Gradient with respect to weight
caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());
caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, top[0]->num(),
(Dtype)1., top_diff, bottom_data, (Dtype)0.,
this->blobs_[0]->mutable_cpu_diff());
if (biasterm_) {
// Gradient with respect to bias
caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()), (Dtype)0.,
this->blobs_[1]->mutable_cpu_diff());
caffe_cpu_gemv<Dtype>(
CblasTrans, top[0]->num(), N_, (Dtype)1., top_diff,
reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()),
(Dtype)0., this->blobs_[1]->mutable_cpu_diff());
}
if (propagate_down) {
// Gradient with respect to bottom data
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
(*bottom)[0]->mutable_cpu_diff());
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, top[0]->num(), K_, N_,
(Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
(Dtype)0., (*bottom)[0]->mutable_cpu_diff());
}
return Dtype(0);
}
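
Aside (not part of the diff): with the batch size taken from bottom[0]->num() / top[0]->num() at run time, the inner-product GEMMs above compute top = bottom * weight^T + bias over however many images the bottom blob currently holds. A minimal self-contained plain-loop version of the forward pass, with dimensions passed explicitly (function and variable names are illustrative):

#include <vector>

// Plain-loop equivalent of the forward GEMM above:
//   top (M x N) = bottom (M x K) * weight (N x K)^T + bias (N),
// where M is the runtime batch size rather than a value cached at SetUp.
// top must already be sized to M * N elements.
void inner_product_forward(const std::vector<float>& bottom,  // M x K
                           const std::vector<float>& weight,  // N x K
                           const std::vector<float>& bias,    // N
                           std::vector<float>* top,           // M x N
                           int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = bias[n];
      for (int k = 0; k < K; ++k) {
        acc += bottom[m * K + k] * weight[n * K + k];
      }
      (*top)[m * N + n] = acc;
    }
  }
}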