diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp
index b2a5c2f697d..ef10aea53f0 100644
--- a/include/caffe/blob.hpp
+++ b/include/caffe/blob.hpp
@@ -20,9 +20,23 @@ class Blob {
  public:
   Blob()
        : data_(), diff_(), num_(0), channels_(0), height_(0), width_(0),
-       count_(0) {}
+       count_(0), capacity_(0) {}
   explicit Blob(const int num, const int channels, const int height,
     const int width);
+  /**
+   * @brief Change the dimensions of the blob, allocating new memory if
+   *        necessary.
+   *
+   * This function can be called both to create an initial allocation
+   * of memory, and to adjust the dimensions of a top blob during Layer::Reshape
+   * or Layer::Forward. When changing the size of blob, memory will only be
+   * reallocated if sufficient memory does not already exist, and excess memory
+   * will never be freed.
+   *
+   * Note that reshaping an input blob and immediately calling Net::Backward is
+   * an error; either Net::Forward or Net::Reshape need to be called to
+   * propagate the new input shape to higher layers.
+   */
   void Reshape(const int num, const int channels, const int height,
     const int width);
   void ReshapeLike(const Blob& other);
@@ -120,6 +134,7 @@ class Blob {
   int height_;
   int width_;
   int count_;
+  int capacity_;
 
   DISABLE_COPY_AND_ASSIGN(Blob);
 };  // class Blob
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index 190b5c24b05..1f945ca34e9 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -40,6 +40,8 @@ class ArgMaxLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_ARGMAX;
@@ -81,6 +83,8 @@ class ConcatLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_CONCAT;
@@ -159,6 +163,8 @@ class EltwiseLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_ELTWISE;
@@ -178,7 +184,7 @@ class EltwiseLayer : public Layer<Dtype> {
 
   EltwiseParameter_EltwiseOp op_;
   vector<Dtype> coeffs_;
-  shared_ptr<Blob<int> > max_idx_;
+  Blob<int> max_idx_;
   bool stable_prod_grad_;
 };
 
@@ -198,7 +204,7 @@ class FlattenLayer : public Layer<Dtype> {
  public:
   explicit FlattenLayer(const LayerParameter& param)
       : Layer<Dtype>(param) {}
-  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
@@ -251,6 +257,8 @@ class InnerProductLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_INNER_PRODUCT;
@@ -285,7 +293,7 @@ class MVNLayer : public Layer<Dtype> {
  public:
   explicit MVNLayer(const LayerParameter& param)
       : Layer<Dtype>(param) {}
-  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
@@ -319,7 +327,7 @@ class SilenceLayer : public Layer<Dtype> {
  public:
  explicit SilenceLayer(const
LayerParameter& param) : Layer(param) {} - virtual void LayerSetUp(const vector*>& bottom, + virtual void Reshape(const vector*>& bottom, vector*>* top) {} virtual inline LayerParameter_LayerType type() const { @@ -351,7 +359,7 @@ class SoftmaxLayer : public Layer { public: explicit SoftmaxLayer(const LayerParameter& param) : Layer(param) {} - virtual void LayerSetUp(const vector*>& bottom, + virtual void Reshape(const vector*>& bottom, vector*>* top); virtual inline LayerParameter_LayerType type() const { @@ -388,6 +396,8 @@ class CuDNNSoftmaxLayer : public SoftmaxLayer { : SoftmaxLayer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual ~CuDNNSoftmaxLayer(); protected: @@ -413,7 +423,7 @@ class SplitLayer : public Layer { public: explicit SplitLayer(const LayerParameter& param) : Layer(param) {} - virtual void LayerSetUp(const vector*>& bottom, + virtual void Reshape(const vector*>& bottom, vector*>* top); virtual inline LayerParameter_LayerType type() const { @@ -448,6 +458,8 @@ class SliceLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_SLICE; diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 15158029436..8e2637b0658 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -40,6 +40,9 @@ class BaseDataLayer : public Layer { vector*>* top); virtual void DataLayerSetUp(const vector*>& bottom, vector*>* top) {} + // Data layers have no bottoms, so reshaping is trivial. + virtual void Reshape(const vector*>& bottom, + vector*>* top) {} virtual void Backward_cpu(const vector*>& top, const vector& propagate_down, vector*>* bottom) {} @@ -134,6 +137,9 @@ class DummyDataLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + // Data layers have no bottoms, so reshaping is trivial. + virtual void Reshape(const vector*>& bottom, + vector*>* top) {} virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_DUMMY_DATA; @@ -166,6 +172,9 @@ class HDF5DataLayer : public Layer { virtual ~HDF5DataLayer(); virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + // Data layers have no bottoms, so reshaping is trivial. + virtual void Reshape(const vector*>& bottom, + vector*>* top) {} virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_HDF5_DATA; @@ -204,6 +213,9 @@ class HDF5OutputLayer : public Layer { virtual ~HDF5OutputLayer(); virtual void LayerSetUp(const vector*>& bottom, vector*>* top) {} + // Data layers have no bottoms, so reshaping is trivial. + virtual void Reshape(const vector*>& bottom, + vector*>* top) {} virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_HDF5_OUTPUT; diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index 59e6ccf4ef9..e160075b939 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -48,35 +48,54 @@ class Layer { * * @param bottom the preshaped input blobs * @param top - * the allocated but unshaped output blobs, to be shaped by LayerSetUp + * the allocated but unshaped output blobs, to be shaped by Reshape * * Checks that the number of bottom and top blobs is correct. 
-   * Calls LayerSetUp to do special layer setup for individual layer types.
+   * Calls LayerSetUp to do special layer setup for individual layer types,
+   * followed by Reshape to set up sizes of top blobs and internal buffers.
    * Sets up the loss weight multiplier blobs for any non-zero loss weights.
    * This method may not be overridden.
    */
   void SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
     CheckBlobCounts(bottom, *top);
     LayerSetUp(bottom, top);
+    Reshape(bottom, top);
     SetLossWeights(top);
   }
 
   /**
-   * @brief Does layer-specific setup: your layer should implement this.
+   * @brief Does layer-specific setup: your layer should implement this function
+   *        as well as Reshape.
    *
    * @param bottom
    *     the preshaped input blobs, whose data fields store the input data for
    *     this layer
    * @param top
-   *     the allocated but unshaped output blobs, to be initialized by LayerSetUp
+   *     the allocated but unshaped output blobs
    *
-   * This method should be used to do layer-specific setup. At a minimum, this
-   * includes reshaping the empty top blobs to the shape as dictated by the
-   * shapes of the bottom blobs and any relevant parameters from the
-   * layer_param_.
+   * This method should do one-time layer-specific setup. This includes reading
+   * and processing relevant parameters from the layer_param_.
+   * Setting up the shapes of top blobs and internal buffers should be done in
+   * Reshape, which will be called before the forward pass to
+   * adjust the top blob sizes.
    */
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) { NOT_IMPLEMENTED; }
+      vector<Blob<Dtype>*>* top) {}
+
+  /**
+   * @brief Adjust the shapes of top blobs and internal buffers to accommodate
+   *        the shapes of the bottom blobs.
+   *
+   * @param bottom the input blobs, with the requested input shapes
+   * @param top the top blobs, which should be reshaped as needed
+   *
+   * This method should reshape top blobs as needed according to the shapes
+   * of the bottom (input) blobs, as well as reshaping any internal buffers
+   * and making any other necessary adjustments so that the layer can
+   * accommodate the bottom blobs.
+   */
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) = 0;
 
   /**
    * @brief Given the bottom blobs, compute the top blobs and the loss.
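For reference, a minimal sketch (not part of the patch) of a layer written against this split interface, assuming the Caffe headers at this revision; ScaleBy2Layer and its scale_ member are hypothetical and only illustrate where one-time parameter reading (LayerSetUp) ends and shape-dependent work (Reshape) begins.

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"

namespace caffe {

template <typename Dtype>
class ScaleBy2Layer : public Layer<Dtype> {
 public:
  explicit ScaleBy2Layer(const LayerParameter& param) : Layer<Dtype>(param) {}
  // One-time setup: only read and validate layer_param_; fix no shapes here.
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
    scale_ = Dtype(2);  // a real layer would read this from this->layer_param_
  }
  // Called by SetUp and again before every Forward: size the top (and any
  // internal buffers) from the current bottom shape.
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
    (*top)[0]->ReshapeLike(*bottom[0]);
  }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
    const Dtype* bottom_data = bottom[0]->cpu_data();
    Dtype* top_data = (*top)[0]->mutable_cpu_data();
    for (int i = 0; i < bottom[0]->count(); ++i) {
      top_data[i] = scale_ * bottom_data[i];  // y = 2x
    }
  }
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
    if (!propagate_down[0]) { return; }
    const Dtype* top_diff = top[0]->cpu_diff();
    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
    for (int i = 0; i < top[0]->count(); ++i) {
      bottom_diff[i] = scale_ * top_diff[i];  // dL/dx = 2 * dL/dy
    }
  }

  Dtype scale_;
};

}  // namespace caffe

The base class's SetUp then guarantees Reshape has run once before SetLossWeights, and Net::ForwardFromTo (see src/caffe/net.cpp below) re-invokes Reshape before every Forward, which is what makes on-the-fly resizing possible.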
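At the net level, the payoff is on-the-fly input resizing. The following is a sketch of the calling pattern implied by the Blob::Reshape and Net::Reshape comments in this patch; the net pointer, the 1x3 input shape, and the fill step are placeholders. Per the note in blob.hpp, Backward is only valid after the new shape has been propagated by Net::Reshape or a forward pass.

#include "caffe/net.hpp"

// Resize the net's first input blob and propagate the new shape.
void ForwardAtNewSize(caffe::Net<float>* net, int height, int width) {
  caffe::Blob<float>* input = net->input_blobs()[0];
  // Only reallocates if the new count exceeds the blob's current capacity.
  input->Reshape(1, 3, height, width);
  // Propagate shapes bottom-to-top without computing anything, e.g. to
  // inspect the resulting output sizes.
  net->Reshape();
  // ... fill input->mutable_cpu_data() with 1 * 3 * height * width values ...
  net->ForwardPrefilled();
}

The same pattern is exposed to pycaffe as net.reshape() through the binding added in python/caffe/_caffe.cpp.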
diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp index a29c445d51e..b95e919c25f 100644 --- a/include/caffe/loss_layers.hpp +++ b/include/caffe/loss_layers.hpp @@ -34,6 +34,8 @@ class AccuracyLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_ACCURACY; @@ -97,6 +99,8 @@ class LossLayer : public Layer { : Layer(param) {} virtual void LayerSetUp( const vector*>& bottom, vector*>* top); + virtual void Reshape( + const vector*>& bottom, vector*>* top); virtual inline int ExactNumBottomBlobs() const { return 2; } @@ -148,7 +152,7 @@ class EuclideanLossLayer : public LossLayer { public: explicit EuclideanLossLayer(const LayerParameter& param) : LossLayer(param), diff_() {} - virtual void LayerSetUp(const vector*>& bottom, + virtual void Reshape(const vector*>& bottom, vector*>* top); virtual inline LayerParameter_LayerType type() const { @@ -339,6 +343,8 @@ class InfogainLossLayer : public LossLayer { : LossLayer(param), infogain_() {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); // InfogainLossLayer takes 2-3 bottom Blobs; if there are 3 the third should // be the infogain matrix. (Otherwise the infogain matrix is loaded from a @@ -428,7 +434,7 @@ class MultinomialLogisticLossLayer : public LossLayer { public: explicit MultinomialLogisticLossLayer(const LayerParameter& param) : LossLayer(param) {} - virtual void LayerSetUp(const vector*>& bottom, + virtual void Reshape(const vector*>& bottom, vector*>* top); virtual inline LayerParameter_LayerType type() const { @@ -510,6 +516,8 @@ class SigmoidCrossEntropyLossLayer : public LossLayer { sigmoid_output_(new Blob()) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_SIGMOID_CROSS_ENTROPY_LOSS; @@ -606,6 +614,8 @@ class SoftmaxWithLossLayer : public LossLayer { softmax_layer_(new SoftmaxLayer(param)) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_SOFTMAX_LOSS; diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index 3dcd0d5d598..1d06dc45533 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -67,6 +67,14 @@ class Net { void BackwardFrom(int start); void BackwardTo(int end); + /** + * @brief Reshape all layers from bottom to top. + * + * This is useful to propagate changes to layer sizes without running + * a forward pass, e.g. to compute output feature size. 
+ */ + void Reshape(); + Dtype ForwardBackward(const vector* > & bottom) { Dtype loss; Forward(bottom, &loss); diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp index 36acf96e5af..0968a2007dc 100644 --- a/include/caffe/neuron_layers.hpp +++ b/include/caffe/neuron_layers.hpp @@ -26,7 +26,7 @@ class NeuronLayer : public Layer { public: explicit NeuronLayer(const LayerParameter& param) : Layer(param) {} - virtual void LayerSetUp(const vector*>& bottom, + virtual void Reshape(const vector*>& bottom, vector*>* top); virtual inline LayerParameter_LayerType type() const { @@ -170,6 +170,8 @@ class DropoutLayer : public NeuronLayer { : NeuronLayer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_DROPOUT; @@ -367,6 +369,8 @@ class CuDNNReLULayer : public ReLULayer { : ReLULayer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual ~CuDNNReLULayer(); protected: @@ -449,6 +453,8 @@ class CuDNNSigmoidLayer : public SigmoidLayer { : SigmoidLayer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual ~CuDNNSigmoidLayer(); protected: @@ -533,6 +539,8 @@ class CuDNNTanHLayer : public TanHLayer { : TanHLayer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual ~CuDNNTanHLayer(); protected: diff --git a/include/caffe/util/cudnn.hpp b/include/caffe/util/cudnn.hpp index e7ddea73e6b..aca5bd713fc 100644 --- a/include/caffe/util/cudnn.hpp +++ b/include/caffe/util/cudnn.hpp @@ -56,22 +56,26 @@ template<> class dataType { }; template -inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc, +inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) { + CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc)); +} + +template +inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc, int n, int c, int h, int w, int stride_n, int stride_c, int stride_h, int stride_w) { - CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc)); CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType::type, n, c, h, w, stride_n, stride_c, stride_h, stride_w)); } template -inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc, +inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc, int n, int c, int h, int w) { const int stride_w = 1; const int stride_h = w * stride_w; const int stride_c = h * stride_h; const int stride_n = c * stride_c; - createTensor4dDesc(desc, n, c, h, w, + setTensor4dDesc(desc, n, c, h, w, stride_n, stride_c, stride_h, stride_w); } @@ -84,10 +88,14 @@ inline void createFilterDesc(cudnnFilterDescriptor_t* desc, } template -inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv, +inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) { + CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv)); +} + +template +inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv, cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter, int pad_h, int pad_w, int stride_h, int stride_w) { - CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv)); CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter, pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION)); } diff --git 
a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 9c8656f1f2b..1e7f3fcb297 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -67,6 +67,8 @@ class ConvolutionLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_CONVOLUTION; @@ -131,6 +133,8 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { : ConvolutionLayer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual ~CuDNNConvolutionLayer(); protected: @@ -163,6 +167,8 @@ class Im2colLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_IM2COL; @@ -203,6 +209,8 @@ class LRNLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_LRN; @@ -278,6 +286,8 @@ class PoolingLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual inline LayerParameter_LayerType type() const { return LayerParameter_LayerType_POOLING; @@ -323,6 +333,8 @@ class CuDNNPoolingLayer : public PoolingLayer { : PoolingLayer(param) {} virtual void LayerSetUp(const vector*>& bottom, vector*>* top); + virtual void Reshape(const vector*>& bottom, + vector*>* top); virtual ~CuDNNPoolingLayer(); protected: diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 39e5d09077c..5a81a42329b 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -149,6 +149,7 @@ BOOST_PYTHON_MODULE(_caffe) { .def(bp::init()) .def("_forward", &PyNet::Forward) .def("_backward", &PyNet::Backward) + .def("reshape", &PyNet::Reshape) .def("set_mode_cpu", &PyNet::set_mode_cpu) .def("set_mode_gpu", &PyNet::set_mode_gpu) .def("set_phase_train", &PyNet::set_phase_train) diff --git a/python/caffe/_caffe.hpp b/python/caffe/_caffe.hpp index 5884900fcfe..ba04d276351 100644 --- a/python/caffe/_caffe.hpp +++ b/python/caffe/_caffe.hpp @@ -98,6 +98,7 @@ class PyNet { void Forward(int start, int end) { net_->ForwardFromTo(start, end); } void Backward(int start, int end) { net_->BackwardFromTo(start, end); } + void Reshape() { net_->Reshape(); } void set_input_arrays(bp::object data_obj, bp::object labels_obj); diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 9fd1232a9ad..cfffc379eb1 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -17,12 +17,10 @@ void Blob::Reshape(const int num, const int channels, const int height, height_ = height; width_ = width; count_ = num_ * channels_ * height_ * width_; - if (count_) { - data_.reset(new SyncedMemory(count_ * sizeof(Dtype))); - diff_.reset(new SyncedMemory(count_ * sizeof(Dtype))); - } else { - data_.reset(reinterpret_cast(NULL)); - diff_.reset(reinterpret_cast(NULL)); + if (count_ > capacity_) { + capacity_ = count_; + data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype))); + diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype))); } } @@ -33,7 +31,9 @@ void 
Blob::ReshapeLike(const Blob& other) { template Blob::Blob(const int num, const int channels, const int height, - const int width) { + const int width) + // capacity_ must be initialized before calling Reshape + : capacity_(0) { Reshape(num, channels, height, width); } diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp index 062e927183b..3e69bc84faa 100644 --- a/src/caffe/layers/accuracy_layer.cpp +++ b/src/caffe/layers/accuracy_layer.cpp @@ -14,6 +14,11 @@ template void AccuracyLayer::LayerSetUp( const vector*>& bottom, vector*>* top) { top_k_ = this->layer_param_.accuracy_param().top_k(); +} + +template +void AccuracyLayer::Reshape( + const vector*>& bottom, vector*>* top) { CHECK_EQ(bottom[0]->num(), bottom[1]->num()) << "The data and label should have the same number."; CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp index 4b67f24cb9a..0d1a107257b 100644 --- a/src/caffe/layers/argmax_layer.cpp +++ b/src/caffe/layers/argmax_layer.cpp @@ -16,6 +16,11 @@ void ArgMaxLayer::LayerSetUp(const vector*>& bottom, CHECK_GE(top_k_, 1) << " top k must not be less than 1."; CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) << "top_k must be less than or equal to the number of classes."; +} + +template +void ArgMaxLayer::Reshape(const vector*>& bottom, + vector*>* top) { if (out_max_val_) { // Produces max_ind and max_val (*top)[0]->Reshape(bottom[0]->num(), 2, top_k_, 1); diff --git a/src/caffe/layers/concat_layer.cpp b/src/caffe/layers/concat_layer.cpp index 73d28b17850..10a11f1bb7c 100644 --- a/src/caffe/layers/concat_layer.cpp +++ b/src/caffe/layers/concat_layer.cpp @@ -14,7 +14,11 @@ void ConcatLayer::LayerSetUp(const vector*>& bottom, "concat_dim should be >= 0"; CHECK_LE(concat_dim_, 1) << "For now concat_dim <=1, it can only concat num and channels"; +} +template +void ConcatLayer::Reshape(const vector*>& bottom, + vector*>* top) { // Initialize with the first blob. count_ = bottom[0]->count(); num_ = bottom[0]->num(); diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index 769dfa671f6..58918fd4baf 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -47,47 +47,18 @@ void ConvolutionLayer::LayerSetUp(const vector*>& bottom, stride_h_ = conv_param.stride_h(); stride_w_ = conv_param.stride_w(); } - num_ = bottom[0]->num(); + // Configure output channels and groups. channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - // TODO: generalize to handle inputs of different shapes. - for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { - CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num."; - CHECK_EQ(channels_, bottom[bottom_id]->channels()) - << "Inputs must have same channels."; - CHECK_EQ(height_, bottom[bottom_id]->height()) - << "Inputs must have same height."; - CHECK_EQ(width_, bottom[bottom_id]->width()) - << "Inputs must have same width."; - } - // Configure output channels, groups, and spatial dimensions. 
num_output_ = this->layer_param_.convolution_param().num_output(); CHECK_GT(num_output_, 0); group_ = this->layer_param_.convolution_param().group(); CHECK_EQ(channels_ % group_, 0); CHECK_EQ(num_output_ % group_, 0) << "Number of output should be multiples of group."; - height_out_ = - (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1; - width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1; - for (int top_id = 0; top_id < top->size(); ++top_id) { - (*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_); - } - // Prepare the matrix multiplication computation. - // Each input will be convolved as a single GEMM. - M_ = num_output_ / group_; - K_ = channels_ * kernel_h_ * kernel_w_ / group_; - N_ = height_out_ * width_out_; - // The im2col result buffer holds one image at a time to avoid - // overly large memory usage. - col_buffer_.Reshape( - 1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_); // Handle the parameters: weights and biases. // - blobs_[0] holds the filter weights // - blobs_[1] holds the biases (optional) bias_term_ = this->layer_param_.convolution_param().bias_term(); - // Check if we need to set up the weights. if (this->blobs_.size() > 0) { LOG(INFO) << "Skipping parameter initialization"; } else { @@ -112,16 +83,54 @@ void ConvolutionLayer::LayerSetUp(const vector*>& bottom, bias_filler->Fill(this->blobs_[1].get()); } } + // Propagate gradients to the parameters (as directed by backward pass). + this->param_propagate_down_.resize(this->blobs_.size(), true); +} + +template +void ConvolutionLayer::Reshape(const vector*>& bottom, + vector*>* top) { + num_ = bottom[0]->num(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with" + " convolution kernel."; + // TODO: generalize to handle inputs of different shapes. + for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { + CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num."; + CHECK_EQ(channels_, bottom[bottom_id]->channels()) + << "Inputs must have same channels."; + CHECK_EQ(height_, bottom[bottom_id]->height()) + << "Inputs must have same height."; + CHECK_EQ(width_, bottom[bottom_id]->width()) + << "Inputs must have same width."; + } + // Shape the tops. + height_out_ = + (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1; + width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1; + for (int top_id = 0; top_id < top->size(); ++top_id) { + (*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_); + } + // Prepare the matrix multiplication computation. + // Each input will be convolved as a single GEMM. + M_ = num_output_ / group_; + K_ = channels_ * kernel_h_ * kernel_w_ / group_; + N_ = height_out_ * width_out_; + // The im2col result buffer will only hold one image at a time to avoid + // overly large memory usage. + col_buffer_.Reshape( + 1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_); + for (int top_id = 0; top_id < top->size(); ++top_id) { + (*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_); + } // Set up the all ones "bias multiplier" for adding biases by BLAS if (bias_term_) { bias_multiplier_.Reshape(1, 1, 1, N_); caffe_set(N_, Dtype(1), bias_multiplier_.mutable_cpu_data()); } - // Propagate gradients to the parameters (as directed by backward pass). 
- this->param_propagate_down_.resize(this->blobs_.size(), true); } - template void ConvolutionLayer::Forward_cpu(const vector*>& bottom, vector*>* top) { diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index eaacddcb0cc..137bbab1976 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -21,7 +21,7 @@ template void CuDNNConvolutionLayer::LayerSetUp( const vector*>& bottom, vector*>* top) { ConvolutionLayer::LayerSetUp(bottom, top); - // Initialize CUDA streams and cuNN. + // Initialize CUDA streams and cuDNN. stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; @@ -32,10 +32,6 @@ void CuDNNConvolutionLayer::LayerSetUp( } // Set the indexing parameters. - bottom_offset_ = (this->channels_ / this->group_) - * this->height_ * this->width_; - top_offset_ = (this->num_output_ / this->group_) - * this->height_out_ * this->width_out_; weight_offset_ = (this->num_output_ / this->group_) * (this->channels_ / this->group_) * this->kernel_h_ * this->kernel_w_; bias_offset_ = (this->num_output_ / this->group_); @@ -48,33 +44,54 @@ void CuDNNConvolutionLayer::LayerSetUp( // Create tensor descriptor(s) for data and corresponding convolution(s). for (int i = 0; i < bottom.size(); i++) { cudnnTensor4dDescriptor_t bottom_desc; - cudnn::createTensor4dDesc(&bottom_desc, + cudnn::createTensor4dDesc(&bottom_desc); + bottom_descs_.push_back(bottom_desc); + cudnnTensor4dDescriptor_t top_desc; + cudnn::createTensor4dDesc(&top_desc); + top_descs_.push_back(top_desc); + cudnnConvolutionDescriptor_t conv_desc; + cudnn::createConvolutionDesc(&conv_desc); + conv_descs_.push_back(conv_desc); + } + + // Tensor descriptor for bias. + if (this->bias_term_) { + cudnn::createTensor4dDesc(&bias_desc_); + } +} + +template +void CuDNNConvolutionLayer::Reshape( + const vector*>& bottom, vector*>* top) { + ConvolutionLayer::Reshape(bottom, top); + bottom_offset_ = (this->channels_ / this->group_) + * this->height_ * this->width_; + top_offset_ = (this->num_output_ / this->group_) + * this->height_out_ * this->width_out_; + + for (int i = 0; i < bottom.size(); i++) { + cudnn::setTensor4dDesc(&bottom_descs_[i], this->num_, this->channels_ / this->group_, this->height_, this->width_, this->channels_ * this->height_ * this->width_, this->height_ * this->width_, this->width_, 1); - bottom_descs_.push_back(bottom_desc); - cudnnTensor4dDescriptor_t top_desc; - cudnn::createTensor4dDesc(&top_desc, + cudnn::setTensor4dDesc(&top_descs_[i], this->num_, this->num_output_ / this->group_, this->height_out_, this->width_out_, this->num_output_ * this->height_out_ * this->width_out_, this->height_out_ * this->width_out_, this->width_out_, 1); - top_descs_.push_back(top_desc); - cudnnConvolutionDescriptor_t conv_desc; - cudnn::createConvolutionDesc(&conv_desc, bottom_desc, + cudnn::setConvolutionDesc(&conv_descs_[i], bottom_descs_[i], filter_desc_, this->pad_h_, this->pad_w_, this->stride_h_, this->stride_w_); - conv_descs_.push_back(conv_desc); } // Tensor descriptor for bias. 
if (this->bias_term_) { - cudnn::createTensor4dDesc(&bias_desc_, + cudnn::setTensor4dDesc(&bias_desc_, 1, this->num_output_ / this->group_, 1, 1); } } diff --git a/src/caffe/layers/cudnn_pooling_layer.cpp b/src/caffe/layers/cudnn_pooling_layer.cpp index 23c52012f1e..5aea0dc886e 100644 --- a/src/caffe/layers/cudnn_pooling_layer.cpp +++ b/src/caffe/layers/cudnn_pooling_layer.cpp @@ -15,15 +15,23 @@ void CuDNNPoolingLayer::LayerSetUp(const vector*>& bottom, PoolingLayer::LayerSetUp(bottom, top); CUDNN_CHECK(cudnnCreate(&handle_)); - cudnn::createTensor4dDesc(&bottom_desc_, bottom[0]->num(), - this->channels_, this->height_, this->width_); - cudnn::createTensor4dDesc(&top_desc_, bottom[0]->num(), - this->channels_, this->pooled_height_, this->pooled_width_); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); cudnn::createPoolingDesc(&pooling_desc_, this->layer_param_.pooling_param().pool(), &mode_, this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_); } +template +void CuDNNPoolingLayer::Reshape(const vector*>& bottom, + vector*>* top) { + PoolingLayer::Reshape(bottom, top); + cudnn::setTensor4dDesc(&bottom_desc_, bottom[0]->num(), + this->channels_, this->height_, this->width_); + cudnn::setTensor4dDesc(&top_desc_, bottom[0]->num(), + this->channels_, this->pooled_height_, this->pooled_width_); +} + template CuDNNPoolingLayer::~CuDNNPoolingLayer() { cudnnDestroyTensor4dDescriptor(bottom_desc_); diff --git a/src/caffe/layers/cudnn_relu_layer.cpp b/src/caffe/layers/cudnn_relu_layer.cpp index f8bf77f1844..083868f572f 100644 --- a/src/caffe/layers/cudnn_relu_layer.cpp +++ b/src/caffe/layers/cudnn_relu_layer.cpp @@ -13,12 +13,20 @@ void CuDNNReLULayer::LayerSetUp(const vector*>& bottom, ReLULayer::LayerSetUp(bottom, top); // initialize cuDNN CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); +} + +template +void CuDNNReLULayer::Reshape(const vector*>& bottom, + vector*>* top) { + ReLULayer::Reshape(bottom, top); const int N = bottom[0]->num(); const int K = bottom[0]->channels(); const int H = bottom[0]->height(); const int W = bottom[0]->width(); - cudnn::createTensor4dDesc(&bottom_desc_, N, K, H, W); - cudnn::createTensor4dDesc(&top_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&bottom_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&top_desc_, N, K, H, W); } template diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cpp b/src/caffe/layers/cudnn_sigmoid_layer.cpp index 488c7545dba..3fe800db6f4 100644 --- a/src/caffe/layers/cudnn_sigmoid_layer.cpp +++ b/src/caffe/layers/cudnn_sigmoid_layer.cpp @@ -13,12 +13,20 @@ void CuDNNSigmoidLayer::LayerSetUp(const vector*>& bottom, SigmoidLayer::LayerSetUp(bottom, top); // initialize cuDNN CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); +} + +template +void CuDNNSigmoidLayer::Reshape(const vector*>& bottom, + vector*>* top) { + SigmoidLayer::Reshape(bottom, top); const int N = bottom[0]->num(); const int K = bottom[0]->channels(); const int H = bottom[0]->height(); const int W = bottom[0]->width(); - cudnn::createTensor4dDesc(&bottom_desc_, N, K, H, W); - cudnn::createTensor4dDesc(&top_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&bottom_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&top_desc_, N, K, H, W); } template diff --git a/src/caffe/layers/cudnn_softmax_layer.cpp b/src/caffe/layers/cudnn_softmax_layer.cpp index 6dab2d6ac32..79ba5237ae3 100644 --- 
a/src/caffe/layers/cudnn_softmax_layer.cpp +++ b/src/caffe/layers/cudnn_softmax_layer.cpp @@ -17,12 +17,20 @@ void CuDNNSoftmaxLayer::LayerSetUp(const vector*>& bottom, SoftmaxLayer::LayerSetUp(bottom, top); // Initialize CUDNN. CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); +} + +template +void CuDNNSoftmaxLayer::Reshape(const vector*>& bottom, + vector*>* top) { + SoftmaxLayer::Reshape(bottom, top); int N = bottom[0]->num(); int K = bottom[0]->channels(); int H = bottom[0]->height(); int W = bottom[0]->width(); - cudnn::createTensor4dDesc(&bottom_desc_, N, K, H, W); - cudnn::createTensor4dDesc(&top_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&bottom_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&top_desc_, N, K, H, W); } template diff --git a/src/caffe/layers/cudnn_tanh_layer.cpp b/src/caffe/layers/cudnn_tanh_layer.cpp index 32b6611e40b..7a5c06f6596 100644 --- a/src/caffe/layers/cudnn_tanh_layer.cpp +++ b/src/caffe/layers/cudnn_tanh_layer.cpp @@ -13,12 +13,20 @@ void CuDNNTanHLayer::LayerSetUp(const vector*>& bottom, TanHLayer::LayerSetUp(bottom, top); // initialize cuDNN CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); +} + +template +void CuDNNTanHLayer::Reshape(const vector*>& bottom, + vector*>* top) { + TanHLayer::Reshape(bottom, top); const int N = bottom[0]->num(); const int K = bottom[0]->channels(); const int H = bottom[0]->height(); const int W = bottom[0]->width(); - cudnn::createTensor4dDesc(&bottom_desc_, N, K, H, W); - cudnn::createTensor4dDesc(&top_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&bottom_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&top_desc_, N, K, H, W); } template diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp index 52537d1aba9..47feb1d2543 100644 --- a/src/caffe/layers/dropout_layer.cpp +++ b/src/caffe/layers/dropout_layer.cpp @@ -14,9 +14,6 @@ template void DropoutLayer::LayerSetUp(const vector*>& bottom, vector*>* top) { NeuronLayer::LayerSetUp(bottom, top); - // Set up the cache for random number generation - rand_vec_.Reshape(bottom[0]->num(), bottom[0]->channels(), - bottom[0]->height(), bottom[0]->width()); threshold_ = this->layer_param_.dropout_param().dropout_ratio(); DCHECK(threshold_ > 0.); DCHECK(threshold_ < 1.); @@ -24,6 +21,15 @@ void DropoutLayer::LayerSetUp(const vector*>& bottom, uint_thres_ = static_cast(UINT_MAX * threshold_); } +template +void DropoutLayer::Reshape(const vector*>& bottom, + vector*>* top) { + NeuronLayer::Reshape(bottom, top); + // Set up the cache for random number generation + rand_vec_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); +} + template void DropoutLayer::Forward_cpu(const vector*>& bottom, vector*>* top) { diff --git a/src/caffe/layers/eltwise_layer.cpp b/src/caffe/layers/eltwise_layer.cpp index 46034be4784..569560f97d3 100644 --- a/src/caffe/layers/eltwise_layer.cpp +++ b/src/caffe/layers/eltwise_layer.cpp @@ -17,6 +17,20 @@ void EltwiseLayer::LayerSetUp(const vector*>& bottom, == EltwiseParameter_EltwiseOp_PROD && this->layer_param().eltwise_param().coeff_size())) << "Eltwise layer only takes coefficients for summation."; + op_ = this->layer_param_.eltwise_param().operation(); + // Blob-wise coefficients for the elementwise operation. 
+ coeffs_ = vector(bottom.size(), 1); + if (this->layer_param().eltwise_param().coeff_size()) { + for (int i = 0; i < bottom.size(); ++i) { + coeffs_[i] = this->layer_param().eltwise_param().coeff(i); + } + } + stable_prod_grad_ = this->layer_param_.eltwise_param().stable_prod_grad(); +} + +template +void EltwiseLayer::Reshape(const vector*>& bottom, + vector*>* top) { const int num = bottom[0]->num(); const int channels = bottom[0]->channels(); const int height = bottom[0]->height(); @@ -28,20 +42,10 @@ void EltwiseLayer::LayerSetUp(const vector*>& bottom, CHECK_EQ(width, bottom[i]->width()); } (*top)[0]->Reshape(num, channels, height, width); - op_ = this->layer_param_.eltwise_param().operation(); - // Blob-wise coefficients for the elementwise operation. - coeffs_ = vector(bottom.size(), 1); - if (this->layer_param().eltwise_param().coeff_size()) { - for (int i = 0; i < bottom.size(); ++i) { - coeffs_[i] = this->layer_param().eltwise_param().coeff(i); - } - } - stable_prod_grad_ = this->layer_param_.eltwise_param().stable_prod_grad(); // If max operation, we will initialize the vector index part. if (this->layer_param_.eltwise_param().operation() == EltwiseParameter_EltwiseOp_MAX && top->size() == 1) { - max_idx_.reset(new Blob(bottom[0]->num(), channels, - height, width)); + max_idx_.Reshape(bottom[0]->num(), channels, height, width); } } @@ -69,7 +73,7 @@ void EltwiseLayer::Forward_cpu( break; case EltwiseParameter_EltwiseOp_MAX: // Initialize - mask = max_idx_->mutable_cpu_data(); + mask = max_idx_.mutable_cpu_data(); caffe_set(count, -1, mask); caffe_set(count, Dtype(-FLT_MAX), top_data); // bottom 0 & 1 @@ -138,7 +142,7 @@ void EltwiseLayer::Backward_cpu(const vector*>& top, } break; case EltwiseParameter_EltwiseOp_MAX: - mask = max_idx_->cpu_data(); + mask = max_idx_.cpu_data(); for (int index = 0; index < count; ++index) { Dtype gradient = 0; if (mask[index] == i) { diff --git a/src/caffe/layers/eltwise_layer.cu b/src/caffe/layers/eltwise_layer.cu index c0d47fd413b..16cb6cc77e3 100644 --- a/src/caffe/layers/eltwise_layer.cu +++ b/src/caffe/layers/eltwise_layer.cu @@ -53,7 +53,7 @@ void EltwiseLayer::Forward_gpu(const vector*>& bottom, } break; case EltwiseParameter_EltwiseOp_MAX: - mask = max_idx_->mutable_gpu_data(); + mask = max_idx_.mutable_gpu_data(); // NOLINT_NEXT_LINE(whitespace/operators) MaxForward <<>>( count, bottom[0]->gpu_data(), bottom[1]->gpu_data(), 0, top_data, mask); @@ -118,7 +118,7 @@ void EltwiseLayer::Backward_gpu(const vector*>& top, } break; case EltwiseParameter_EltwiseOp_MAX: - mask = max_idx_->gpu_data(); + mask = max_idx_.gpu_data(); MaxBackward // NOLINT_NEXT_LINE(whitespace/operators) <<>>( count, top_diff, i, mask, bottom_diff); diff --git a/src/caffe/layers/euclidean_loss_layer.cpp b/src/caffe/layers/euclidean_loss_layer.cpp index be83601f5a1..1b4a13d2ddc 100644 --- a/src/caffe/layers/euclidean_loss_layer.cpp +++ b/src/caffe/layers/euclidean_loss_layer.cpp @@ -8,9 +8,9 @@ namespace caffe { template -void EuclideanLossLayer::LayerSetUp( +void EuclideanLossLayer::Reshape( const vector*>& bottom, vector*>* top) { - LossLayer::LayerSetUp(bottom, top); + LossLayer::Reshape(bottom, top); CHECK_EQ(bottom[0]->channels(), bottom[1]->channels()); CHECK_EQ(bottom[0]->height(), bottom[1]->height()); CHECK_EQ(bottom[0]->width(), bottom[1]->width()); diff --git a/src/caffe/layers/flatten_layer.cpp b/src/caffe/layers/flatten_layer.cpp index 8c1fc74e159..65310cd1669 100644 --- a/src/caffe/layers/flatten_layer.cpp +++ b/src/caffe/layers/flatten_layer.cpp @@ -7,7 
+7,7 @@ namespace caffe { template -void FlattenLayer::LayerSetUp(const vector*>& bottom, +void FlattenLayer::Reshape(const vector*>& bottom, vector*>* top) { int channels_out = bottom[0]->channels() * bottom[0]->height() * bottom[0]->width(); diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp index 02f33f1cff4..870d5a9bde3 100644 --- a/src/caffe/layers/im2col_layer.cpp +++ b/src/caffe/layers/im2col_layer.cpp @@ -45,6 +45,11 @@ void Im2colLayer::LayerSetUp(const vector*>& bottom, stride_h_ = conv_param.stride_h(); stride_w_ = conv_param.stride_w(); } +} + +template +void Im2colLayer::Reshape(const vector*>& bottom, + vector*>* top) { channels_ = bottom[0]->channels(); height_ = bottom[0]->height(); width_ = bottom[0]->width(); diff --git a/src/caffe/layers/infogain_loss_layer.cpp b/src/caffe/layers/infogain_loss_layer.cpp index 91dd89240e4..894cb69811a 100644 --- a/src/caffe/layers/infogain_loss_layer.cpp +++ b/src/caffe/layers/infogain_loss_layer.cpp @@ -14,10 +14,6 @@ template void InfogainLossLayer::LayerSetUp( const vector*>& bottom, vector*>* top) { LossLayer::LayerSetUp(bottom, top); - CHECK_EQ(bottom[1]->channels(), 1); - CHECK_EQ(bottom[1]->height(), 1); - CHECK_EQ(bottom[1]->width(), 1); - Blob* infogain = NULL; if (bottom.size() < 3) { CHECK(this->layer_param_.infogain_loss_param().has_source()) << "Infogain matrix source must be specified."; @@ -25,10 +21,22 @@ void InfogainLossLayer::LayerSetUp( ReadProtoFromBinaryFile( this->layer_param_.infogain_loss_param().source(), &blob_proto); infogain_.FromProto(blob_proto); + } +} + +template +void InfogainLossLayer::Reshape( + const vector*>& bottom, vector*>* top) { + LossLayer::Reshape(bottom, top); + Blob* infogain = NULL; + if (bottom.size() < 3) { infogain = &infogain_; } else { infogain = bottom[2]; } + CHECK_EQ(bottom[1]->channels(), 1); + CHECK_EQ(bottom[1]->height(), 1); + CHECK_EQ(bottom[1]->width(), 1); const int num = bottom[0]->num(); const int dim = bottom[0]->count() / num; CHECK_EQ(infogain->num(), 1); diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp index 3ba0e1f29f6..ecd05a030db 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -14,11 +14,8 @@ void InnerProductLayer::LayerSetUp(const vector*>& bottom, vector*>* top) { const int num_output = this->layer_param_.inner_product_param().num_output(); bias_term_ = this->layer_param_.inner_product_param().bias_term(); - // Figure out the dimensions - M_ = bottom[0]->num(); - K_ = bottom[0]->count() / bottom[0]->num(); N_ = num_output; - (*top)[0]->Reshape(bottom[0]->num(), num_output, 1, 1); + K_ = bottom[0]->count() / bottom[0]->num(); // Check if we need to set up the weights if (this->blobs_.size() > 0) { LOG(INFO) << "Skipping parameter initialization"; @@ -42,12 +39,22 @@ void InnerProductLayer::LayerSetUp(const vector*>& bottom, bias_filler->Fill(this->blobs_[1].get()); } } // parameter initialization - // Setting up the bias multiplier + this->param_propagate_down_.resize(this->blobs_.size(), true); +} + +template +void InnerProductLayer::Reshape(const vector*>& bottom, + vector*>* top) { + // Figure out the dimensions + M_ = bottom[0]->num(); + CHECK_EQ(bottom[0]->count() / bottom[0]->num(), K_) << "Input size " + "incompatible with inner product parameters."; + (*top)[0]->Reshape(bottom[0]->num(), N_, 1, 1); + // Set up the bias multiplier if (bias_term_) { bias_multiplier_.Reshape(1, 1, 1, M_); caffe_set(M_, Dtype(1), 
bias_multiplier_.mutable_cpu_data()); } - this->param_propagate_down_.resize(this->blobs_.size(), true); } template diff --git a/src/caffe/layers/loss_layer.cpp b/src/caffe/layers/loss_layer.cpp index 89d8c91e342..9eb9dbd5c5b 100644 --- a/src/caffe/layers/loss_layer.cpp +++ b/src/caffe/layers/loss_layer.cpp @@ -13,15 +13,20 @@ namespace caffe { template void LossLayer::LayerSetUp( const vector*>& bottom, vector*>* top) { - CHECK_EQ(bottom[0]->num(), bottom[1]->num()) - << "The data and label should have the same number."; - (*top)[0]->Reshape(1, 1, 1, 1); // LossLayers have a non-zero (1) loss by default. if (this->layer_param_.loss_weight_size() == 0) { this->layer_param_.add_loss_weight(Dtype(1)); } } +template +void LossLayer::Reshape( + const vector*>& bottom, vector*>* top) { + CHECK_EQ(bottom[0]->num(), bottom[1]->num()) + << "The data and label should have the same number."; + (*top)[0]->Reshape(1, 1, 1, 1); +} + INSTANTIATE_CLASS(LossLayer); } // namespace caffe diff --git a/src/caffe/layers/lrn_layer.cpp b/src/caffe/layers/lrn_layer.cpp index e81a32ba84f..d9e41e9c137 100644 --- a/src/caffe/layers/lrn_layer.cpp +++ b/src/caffe/layers/lrn_layer.cpp @@ -9,88 +9,81 @@ namespace caffe { template void LRNLayer::LayerSetUp(const vector*>& bottom, vector*>* top) { - num_ = bottom[0]->num(); - channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); size_ = this->layer_param_.lrn_param().local_size(); + CHECK_EQ(size_ % 2, 1) << "LRN only supports odd values for local_size"; pre_pad_ = (size_ - 1) / 2; alpha_ = this->layer_param_.lrn_param().alpha(); beta_ = this->layer_param_.lrn_param().beta(); + if (this->layer_param_.lrn_param().norm_region() == + LRNParameter_NormRegion_WITHIN_CHANNEL) { + // Set up split_layer_ to use inputs in the numerator and denominator. + split_top_vec_.clear(); + split_top_vec_.push_back(&product_input_); + split_top_vec_.push_back(&square_input_); + LayerParameter split_param; + split_layer_.reset(new SplitLayer(split_param)); + split_layer_->SetUp(bottom, &split_top_vec_); + // Set up square_layer_ to square the inputs. + square_bottom_vec_.clear(); + square_top_vec_.clear(); + square_bottom_vec_.push_back(&square_input_); + square_top_vec_.push_back(&square_output_); + LayerParameter square_param; + square_param.mutable_power_param()->set_power(Dtype(2)); + square_layer_.reset(new PowerLayer(square_param)); + square_layer_->SetUp(square_bottom_vec_, &square_top_vec_); + // Set up pool_layer_ to sum over square neighborhoods of the input. + pool_top_vec_.clear(); + pool_top_vec_.push_back(&pool_output_); + LayerParameter pool_param; + pool_param.mutable_pooling_param()->set_pool( + PoolingParameter_PoolMethod_AVE); + pool_param.mutable_pooling_param()->set_pad(pre_pad_); + pool_param.mutable_pooling_param()->set_kernel_size(size_); + pool_layer_.reset(new PoolingLayer(pool_param)); + pool_layer_->SetUp(square_top_vec_, &pool_top_vec_); + // Set up power_layer_ to compute (1 + alpha_/N^2 s)^-beta_, where s is + // the sum of a squared neighborhood (the output of pool_layer_). 
+ power_top_vec_.clear(); + power_top_vec_.push_back(&power_output_); + LayerParameter power_param; + power_param.mutable_power_param()->set_power(-beta_); + power_param.mutable_power_param()->set_scale(alpha_); + power_param.mutable_power_param()->set_shift(Dtype(1)); + power_layer_.reset(new PowerLayer(power_param)); + power_layer_->SetUp(pool_top_vec_, &power_top_vec_); + // Set up a product_layer_ to compute outputs by multiplying inputs by the + // inverse demoninator computed by the power layer. + product_bottom_vec_.clear(); + product_bottom_vec_.push_back(&product_input_); + product_bottom_vec_.push_back(&power_output_); + LayerParameter product_param; + EltwiseParameter* eltwise_param = product_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD); + product_layer_.reset(new EltwiseLayer(product_param)); + product_layer_->SetUp(product_bottom_vec_, top); + } +} + +template +void LRNLayer::Reshape(const vector*>& bottom, + vector*>* top) { + num_ = bottom[0]->num(); + channels_ = bottom[0]->channels(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); switch (this->layer_param_.lrn_param().norm_region()) { case LRNParameter_NormRegion_ACROSS_CHANNELS: (*top)[0]->Reshape(num_, channels_, height_, width_); scale_.Reshape(num_, channels_, height_, width_); break; case LRNParameter_NormRegion_WITHIN_CHANNEL: - { - // Set up split_layer_ to use inputs in the numerator and denominator. - split_top_vec_.clear(); - split_top_vec_.push_back(&product_input_); - split_top_vec_.push_back(&square_input_); - LayerParameter split_param; - split_layer_.reset(new SplitLayer(split_param)); - split_layer_->SetUp(bottom, &split_top_vec_); - // Set up square_layer_ to square the inputs. - square_input_.Reshape(num_, channels_, height_, width_); - square_bottom_vec_.clear(); - square_top_vec_.clear(); - square_bottom_vec_.push_back(&square_input_); - square_top_vec_.push_back(&square_output_); - LayerParameter square_param; - square_param.mutable_power_param()->set_power(Dtype(2)); - square_layer_.reset(new PowerLayer(square_param)); - square_layer_->SetUp(square_bottom_vec_, &square_top_vec_); - CHECK_EQ(square_output_.num(), num_); - CHECK_EQ(square_output_.channels(), channels_); - CHECK_EQ(square_output_.height(), height_); - CHECK_EQ(square_output_.width(), width_); - // Set up pool_layer_ to sum over square neighborhoods of the input. - pool_top_vec_.clear(); - pool_top_vec_.push_back(&pool_output_); - LayerParameter pool_param; - pool_param.mutable_pooling_param()->set_pool( - PoolingParameter_PoolMethod_AVE); - pool_param.mutable_pooling_param()->set_pad(pre_pad_); - pool_param.mutable_pooling_param()->set_kernel_size(size_); - pool_layer_.reset(new PoolingLayer(pool_param)); - pool_layer_->SetUp(square_top_vec_, &pool_top_vec_); - CHECK_EQ(pool_output_.num(), num_); - CHECK_EQ(pool_output_.channels(), channels_); - CHECK_EQ(pool_output_.height(), height_); - CHECK_EQ(pool_output_.width(), width_); - // Set up power_layer_ to compute (1 + alpha_/N^2 s)^-beta_, where s is - // the sum of a squared neighborhood (the output of pool_layer_). 
- power_top_vec_.clear(); - power_top_vec_.push_back(&power_output_); - LayerParameter power_param; - power_param.mutable_power_param()->set_power(-beta_); - power_param.mutable_power_param()->set_scale(alpha_); - power_param.mutable_power_param()->set_shift(Dtype(1)); - power_layer_.reset(new PowerLayer(power_param)); - power_layer_->SetUp(pool_top_vec_, &power_top_vec_); - CHECK_EQ(power_output_.num(), num_); - CHECK_EQ(power_output_.channels(), channels_); - CHECK_EQ(power_output_.height(), height_); - CHECK_EQ(power_output_.width(), width_); - // Set up a product_layer_ to compute outputs by multiplying inputs by the - // inverse demoninator computed by the power layer. - product_bottom_vec_.clear(); - product_bottom_vec_.push_back(&product_input_); - product_bottom_vec_.push_back(&power_output_); - LayerParameter product_param; - EltwiseParameter* eltwise_param = product_param.mutable_eltwise_param(); - eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD); - product_layer_.reset(new EltwiseLayer(product_param)); - product_layer_->SetUp(product_bottom_vec_, top); - CHECK_EQ((*top)[0]->num(), num_); - CHECK_EQ((*top)[0]->channels(), channels_); - CHECK_EQ((*top)[0]->height(), height_); - CHECK_EQ((*top)[0]->width(), width_); - } + split_layer_->Reshape(bottom, &split_top_vec_); + square_layer_->Reshape(square_bottom_vec_, &square_top_vec_); + pool_layer_->Reshape(square_top_vec_, &pool_top_vec_); + power_layer_->Reshape(pool_top_vec_, &power_top_vec_); + product_layer_->Reshape(product_bottom_vec_, top); break; - default: - LOG(FATAL) << "Unknown normalization region."; } } diff --git a/src/caffe/layers/multinomial_logistic_loss_layer.cpp b/src/caffe/layers/multinomial_logistic_loss_layer.cpp index cf96bfe73ed..c0fe1966a4d 100644 --- a/src/caffe/layers/multinomial_logistic_loss_layer.cpp +++ b/src/caffe/layers/multinomial_logistic_loss_layer.cpp @@ -11,9 +11,9 @@ namespace caffe { template -void MultinomialLogisticLossLayer::LayerSetUp( +void MultinomialLogisticLossLayer::Reshape( const vector*>& bottom, vector*>* top) { - LossLayer::LayerSetUp(bottom, top); + LossLayer::Reshape(bottom, top); CHECK_EQ(bottom[1]->channels(), 1); CHECK_EQ(bottom[1]->height(), 1); CHECK_EQ(bottom[1]->width(), 1); diff --git a/src/caffe/layers/mvn_layer.cpp b/src/caffe/layers/mvn_layer.cpp index 4d90702fd55..6a57b3ea7fc 100644 --- a/src/caffe/layers/mvn_layer.cpp +++ b/src/caffe/layers/mvn_layer.cpp @@ -8,7 +8,7 @@ namespace caffe { template -void MVNLayer::LayerSetUp(const vector*>& bottom, +void MVNLayer::Reshape(const vector*>& bottom, vector*>* top) { (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), bottom[0]->height(), bottom[0]->width()); diff --git a/src/caffe/layers/neuron_layer.cpp b/src/caffe/layers/neuron_layer.cpp index eff7948a616..c28e36ea23b 100644 --- a/src/caffe/layers/neuron_layer.cpp +++ b/src/caffe/layers/neuron_layer.cpp @@ -6,13 +6,9 @@ namespace caffe { template -void NeuronLayer::LayerSetUp(const vector*>& bottom, +void NeuronLayer::Reshape(const vector*>& bottom, vector*>* top) { - // NeuronLayer allows in-place computations. If the computation is not - // in-place, we will need to initialize the top blob. 
- if ((*top)[0] != bottom[0]) { - (*top)[0]->ReshapeLike(*bottom[0]); - } + (*top)[0]->ReshapeLike(*bottom[0]); } INSTANTIATE_CLASS(NeuronLayer); diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp index 9e77fa28a45..8e8ffad66a0 100644 --- a/src/caffe/layers/pooling_layer.cpp +++ b/src/caffe/layers/pooling_layer.cpp @@ -60,6 +60,11 @@ void PoolingLayer::LayerSetUp(const vector*>& bottom, CHECK_LT(pad_h_, kernel_h_); CHECK_LT(pad_w_, kernel_w_); } +} + +template +void PoolingLayer::Reshape(const vector*>& bottom, + vector*>* top) { channels_ = bottom[0]->channels(); height_ = bottom[0]->height(); width_ = bottom[0]->width(); diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp index 6e440a820c4..6a48099ae8b 100644 --- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp +++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp @@ -12,8 +12,6 @@ template void SigmoidCrossEntropyLossLayer::LayerSetUp( const vector*>& bottom, vector*>* top) { LossLayer::LayerSetUp(bottom, top); - CHECK_EQ(bottom[0]->count(), bottom[1]->count()) << - "SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count."; sigmoid_bottom_vec_.clear(); sigmoid_bottom_vec_.push_back(bottom[0]); sigmoid_top_vec_.clear(); @@ -21,6 +19,15 @@ void SigmoidCrossEntropyLossLayer::LayerSetUp( sigmoid_layer_->SetUp(sigmoid_bottom_vec_, &sigmoid_top_vec_); } +template +void SigmoidCrossEntropyLossLayer::Reshape( + const vector*>& bottom, vector*>* top) { + LossLayer::Reshape(bottom, top); + CHECK_EQ(bottom[0]->count(), bottom[1]->count()) << + "SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count."; + sigmoid_layer_->Reshape(sigmoid_bottom_vec_, &sigmoid_top_vec_); +} + template void SigmoidCrossEntropyLossLayer::Forward_cpu( const vector*>& bottom, vector*>* top) { diff --git a/src/caffe/layers/slice_layer.cpp b/src/caffe/layers/slice_layer.cpp index 9fa127528a0..ed679a9169e 100644 --- a/src/caffe/layers/slice_layer.cpp +++ b/src/caffe/layers/slice_layer.cpp @@ -18,6 +18,11 @@ void SliceLayer::LayerSetUp(const vector*>& bottom, std::copy(slice_param.slice_point().begin(), slice_param.slice_point().end(), std::back_inserter(slice_point_)); +} + +template +void SliceLayer::Reshape(const vector*>& bottom, + vector*>* top) { count_ = 0; num_ = bottom[0]->num(); channels_ = bottom[0]->channels(); @@ -50,7 +55,6 @@ void SliceLayer::LayerSetUp(const vector*>& bottom, count_ += (*top)[i]->count(); } } - } else { if (slice_dim_ == 0) { CHECK_EQ(num_ % top->size(), 0) diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp index 952db74a4b6..60668a3f8ce 100644 --- a/src/caffe/layers/softmax_layer.cpp +++ b/src/caffe/layers/softmax_layer.cpp @@ -8,7 +8,7 @@ namespace caffe { template -void SoftmaxLayer::LayerSetUp(const vector*>& bottom, +void SoftmaxLayer::Reshape(const vector*>& bottom, vector*>* top) { (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), bottom[0]->height(), bottom[0]->width()); diff --git a/src/caffe/layers/softmax_loss_layer.cpp b/src/caffe/layers/softmax_loss_layer.cpp index e07c8dc9df9..55392c37ca0 100644 --- a/src/caffe/layers/softmax_loss_layer.cpp +++ b/src/caffe/layers/softmax_loss_layer.cpp @@ -17,6 +17,13 @@ void SoftmaxWithLossLayer::LayerSetUp( softmax_top_vec_.clear(); softmax_top_vec_.push_back(&prob_); softmax_layer_->SetUp(softmax_bottom_vec_, &softmax_top_vec_); +} + +template +void SoftmaxWithLossLayer::Reshape( + const vector*>& bottom, vector*>* 
top) { + LossLayer::Reshape(bottom, top); + softmax_layer_->Reshape(softmax_bottom_vec_, &softmax_top_vec_); if (top->size() >= 2) { // softmax output (*top)[1]->ReshapeLike(*bottom[0]); diff --git a/src/caffe/layers/split_layer.cpp b/src/caffe/layers/split_layer.cpp index 3bed347ea6c..40d3600ff17 100644 --- a/src/caffe/layers/split_layer.cpp +++ b/src/caffe/layers/split_layer.cpp @@ -7,7 +7,7 @@ namespace caffe { template -void SplitLayer::LayerSetUp(const vector*>& bottom, +void SplitLayer::Reshape(const vector*>& bottom, vector*>* top) { count_ = bottom[0]->count(); for (int i = 0; i < top->size(); ++i) { diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index a3e7122c2c4..6f4a651fb10 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -504,6 +504,7 @@ Dtype Net::ForwardFromTo(int start, int end) { Dtype loss = 0; for (int i = start; i <= end; ++i) { // LOG(ERROR) << "Forwarding " << layer_names_[i]; + layers_[i]->Reshape(bottom_vecs_[i], &top_vecs_[i]); Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]); loss += layer_loss; if (debug_info_) { ForwardDebugInfo(i); } @@ -679,6 +680,13 @@ void Net::Backward() { BackwardFromTo(layers_.size() - 1, 0); } +template +void Net::Reshape() { + for (int i = 0; i < layers_.size(); ++i) { + layers_[i]->Reshape(bottom_vecs_[i], &top_vecs_[i]); + } +} + template void Net::CopyTrainedLayersFrom(const NetParameter& param) { int num_source_layers = param.layers_size(); diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp index 24f72aef3af..9b10d100afb 100644 --- a/src/caffe/test/test_net.cpp +++ b/src/caffe/test/test_net.cpp @@ -7,6 +7,7 @@ #include "gtest/gtest.h" #include "caffe/common.hpp" +#include "caffe/filler.hpp" #include "caffe/net.hpp" #include "caffe/util/math_functions.hpp" @@ -533,6 +534,68 @@ class NetTest : public MultiDeviceTest { InitNetFromProtoString(proto); } + virtual void InitReshapableNet() { + const string& proto = + "name: 'ReshapableNetwork' " + "input: 'data' " + "input_dim: 1 " + "input_dim: 3 " + "input_dim: 100 " + "input_dim: 100 " + "layers: { " + " name: 'conv1' " + " type: CONVOLUTION " + " bottom: 'data' " + " top: 'conv1' " + " convolution_param { " + " num_output: 5 " + " kernel_size: 3 " + " stride: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0.2 " + " } " + " } " + "} " + "layers: { " + " name: 'relu1' " + " type: RELU " + " bottom: 'conv1' " + " top: 'conv1' " + "} " + "layers: { " + " name: 'pool1' " + " type: POOLING " + " bottom: 'conv1' " + " top: 'pool1' " + " pooling_param { " + " pool: MAX " + " kernel_size: 2 " + " stride: 2 " + " } " + "} " + "layers: { " + " name: 'norm1' " + " type: LRN " + " bottom: 'pool1' " + " top: 'norm1' " + " lrn_param { " + " local_size: 3 " + " } " + "} " + "layers: { " + " name: 'softmax' " + " type: SOFTMAX " + " bottom: 'norm1' " + " top: 'softmax' " + "} "; + InitNetFromProtoString(proto); + } + int seed_; shared_ptr > net_; }; @@ -2028,4 +2091,62 @@ TEST_F(FilterNetTest, TestFilterInOutByExcludeMultiRule) { this->RunFilterNetTest(input_proto_test, output_proto_test); } +TYPED_TEST(NetTest, TestReshape) { + typedef typename TypeParam::Dtype Dtype; + // We set up bottom blobs of two different sizes, switch between + // them, and check that forward and backward both run and the results + // are the same. 
+ Caffe::set_random_seed(this->seed_); + Caffe::set_mode(Caffe::CPU); + FillerParameter filler_param; + filler_param.set_std(1); + GaussianFiller filler(filler_param); + Blob blob1(4, 3, 9, 11); + Blob blob2(2, 3, 12, 10); + filler.Fill(&blob1); + filler.Fill(&blob2); + + this->InitReshapableNet(); + Blob* input_blob = this->net_->input_blobs()[0]; + Blob* output_blob = this->net_->output_blobs()[0]; + input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(), + blob1.width()); + caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data()); + this->net_->ForwardPrefilled(); + // call backward just to make sure it runs + this->net_->Backward(); + Blob output1(output_blob->num(), output_blob->channels(), + output_blob->height(), output_blob->width()); + caffe_copy(output1.count(), output_blob->cpu_data(), + output1.mutable_cpu_data()); + + input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(), + blob2.width()); + caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data()); + this->net_->ForwardPrefilled(); + this->net_->Backward(); + Blob output2(output_blob->num(), output_blob->channels(), + output_blob->height(), output_blob->width()); + caffe_copy(output2.count(), output_blob->cpu_data(), + output2.mutable_cpu_data()); + + input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(), + blob1.width()); + caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data()); + this->net_->ForwardPrefilled(); + this->net_->Backward(); + for (int i = 0; i < output1.count(); ++i) { + CHECK_EQ(*(output1.cpu_data() + i), *(output_blob->cpu_data() + i)); + } + + input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(), + blob2.width()); + caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data()); + this->net_->ForwardPrefilled(); + this->net_->Backward(); + for (int i = 0; i < output2.count(); ++i) { + CHECK_EQ(*(output2.cpu_data() + i), *(output_blob->cpu_data() + i)); + } +} + } // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index fa27fdf2118..c8c8c1a6b4c 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -235,6 +235,9 @@ int time() { const caffe::string& layername = layers[i]->layer_param().name(); timer.Start(); for (int j = 0; j < FLAGS_iterations; ++j) { + // Although Reshape should be essentially free, we include it here + // so that we will notice Reshape performance bugs. + layers[i]->Reshape(bottom_vecs[i], &top_vecs[i]); layers[i]->Forward(bottom_vecs[i], &top_vecs[i]); } LOG(INFO) << layername << "\tforward: " << timer.MilliSeconds() <<