diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp
index 2f6b8f80e68..75101462faf 100644
--- a/include/caffe/blob.hpp
+++ b/include/caffe/blob.hpp
@@ -63,6 +63,7 @@ class Blob {
   }
 
   const Dtype* cpu_data() const;
+  void set_cpu_data(Dtype* data);
   const Dtype* gpu_data() const;
   const Dtype* cpu_diff() const;
   const Dtype* gpu_diff() const;
diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp
index e265c455c50..bed55c3806e 100644
--- a/include/caffe/syncedmem.hpp
+++ b/include/caffe/syncedmem.hpp
@@ -35,17 +35,21 @@ inline void CaffeFreeHost(void* ptr) {
 class SyncedMemory {
  public:
   SyncedMemory()
-      : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED) {}
+      : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED),
+        own_cpu_data_(false) {}
   explicit SyncedMemory(size_t size)
-      : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED) {}
+      : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED),
+        own_cpu_data_(false) {}
   ~SyncedMemory();
   const void* cpu_data();
+  void set_cpu_data(void* data);
   const void* gpu_data();
   void* mutable_cpu_data();
   void* mutable_gpu_data();
   enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED };
   SyncedHead head() { return head_; }
   size_t size() { return size_; }
+
  private:
   void to_cpu();
   void to_gpu();
@@ -53,6 +57,7 @@ class SyncedMemory {
   void* gpu_ptr_;
   size_t size_;
   SyncedHead head_;
+  bool own_cpu_data_;
 
   DISABLE_COPY_AND_ASSIGN(SyncedMemory);
 };  // class SyncedMemory
diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 817bce9f21e..4765398aa7b 100644
--- a/include/caffe/vision_layers.hpp
+++ b/include/caffe/vision_layers.hpp
@@ -624,6 +624,40 @@ class LRNLayer : public Layer<Dtype> {
   vector<Blob<Dtype>*> product_bottom_vec_;
 };
 
+template <typename Dtype>
+class MemoryDataLayer : public Layer<Dtype> {
+ public:
+  explicit MemoryDataLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  // Reset should accept const pointers, but can't, because the memory
+  // will be given to Blob, which is mutable
+  void Reset(Dtype* data, Dtype* label, int n);
+  int datum_channels() { return datum_channels_; }
+  int datum_height() { return datum_height_; }
+  int datum_width() { return datum_width_; }
+  int batch_size() { return batch_size_; }
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+
+  Dtype* data_;
+  Dtype* labels_;
+  int datum_channels_;
+  int datum_height_;
+  int datum_width_;
+  int datum_size_;
+  int batch_size_;
+  int n_;
+  int pos_;
+};
+
 template <typename Dtype>
 class MultinomialLogisticLossLayer : public Layer<Dtype> {
  public:
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index a443f897db9..08899d4dce6 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -158,6 +158,8 @@ struct CaffeNet {
 
   virtual ~CaffeNet() {}
 
+  // this function is mostly redundant with the one below, but should go away
+  // with new pycaffe
   inline void check_array_against_blob(
       PyArrayObject* arr, Blob<float>* blob) {
     CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS);
@@ -170,6 +172,29 @@ struct CaffeNet {
     CHECK_EQ(dims[3], blob->width());
   }
 
+  // generate Python exceptions for badly shaped or discontiguous arrays
+  inline void check_contiguous_array(PyArrayObject* arr, string name,
+      int channels, int height, int width) {
+    if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
+      throw std::runtime_error(name + " must be C contiguous");
+    }
+    if (PyArray_NDIM(arr) != 4) {
+      throw std::runtime_error(name + " must be 4-d");
+    }
+    if (PyArray_TYPE(arr) != NPY_FLOAT32) {
+      throw std::runtime_error(name + " must be float32");
+    }
+    if (PyArray_DIMS(arr)[1] != channels) {
+      throw std::runtime_error(name + " has wrong number of channels");
+    }
+    if (PyArray_DIMS(arr)[2] != height) {
+      throw std::runtime_error(name + " has wrong height");
+    }
+    if (PyArray_DIMS(arr)[3] != width) {
+      throw std::runtime_error(name + " has wrong width");
+    }
+  }
+
   // The actual forward function. It takes in a python list of numpy arrays as
   // input and a python list of numpy arrays as output. The input and output
   // should all have correct shapes, are single-precision floats, and
@@ -267,6 +292,41 @@ struct CaffeNet {
     net_->ForwardPrefilled();
   }
 
+  void set_input_arrays(object data_obj, object labels_obj) {
+    // check that this network has an input MemoryDataLayer
+    shared_ptr<MemoryDataLayer<float> > md_layer =
+        boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]);
+    if (!md_layer) {
+      throw std::runtime_error("set_input_arrays may only be called if the"
+          " first layer is a MemoryDataLayer");
+    }
+
+    // check that we were passed appropriately-sized contiguous memory
+    PyArrayObject* data_arr =
+        reinterpret_cast<PyArrayObject*>(data_obj.ptr());
+    PyArrayObject* labels_arr =
+        reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
+    check_contiguous_array(data_arr, "data array", md_layer->datum_channels(),
+        md_layer->datum_height(), md_layer->datum_width());
+    check_contiguous_array(labels_arr, "labels array", 1, 1, 1);
+    if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
+      throw std::runtime_error("data and labels must have the same first"
+          " dimension");
+    }
+    if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
+      throw std::runtime_error("first dimensions of input arrays must be a"
+          " multiple of batch size");
+    }
+
+    // hold references
+    input_data_ = data_obj;
+    input_labels_ = labels_obj;
+
+    md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)),
+        static_cast<float*>(PyArray_DATA(labels_arr)),
+        PyArray_DIMS(data_arr)[0]);
+  }
+
   // The caffe::Caffe utility functions.
   void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
   void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
@@ -292,6 +352,9 @@ struct CaffeNet {
 
   // The pointer to the internal caffe::Net instant.
   shared_ptr<Net<float> > net_;
+  // if taking input from an ndarray, we need to hold references
+  object input_data_;
+  object input_labels_;
 };
 
 class CaffeSGDSolver {
@@ -301,9 +364,12 @@ class CaffeSGDSolver {
     // exception if param_file can't be opened
     CheckFile(param_file);
     solver_.reset(new SGDSolver<float>(param_file));
+    // we need to explicitly store the net wrapper, rather than constructing
+    // it on the fly, so that it can hold references to Python objects
+    net_.reset(new CaffeNet(solver_->net()));
   }
 
-  CaffeNet net() { return CaffeNet(solver_->net()); }
+  shared_ptr<CaffeNet> net() { return net_; }
   void Solve() { return solver_->Solve(); }
   void SolveResume(const string& resume_file) {
     CheckFile(resume_file);
@@ -311,26 +377,29 @@ class CaffeSGDSolver {
   }
 
  protected:
+  shared_ptr<CaffeNet> net_;
   shared_ptr<SGDSolver<float> > solver_;
 };
 
 
 // The boost python module definition.
 BOOST_PYTHON_MODULE(_caffe) {
-  boost::python::class_<CaffeNet>(
+  // below, we prepend an underscore to methods that will be replaced
+  // in Python
+  boost::python::class_<CaffeNet, shared_ptr<CaffeNet> >(
       "Net", boost::python::init<string, string>())
       .def(boost::python::init<string>())
-      .def("Forward",          &CaffeNet::Forward)
-      .def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
-      .def("Backward",         &CaffeNet::Backward)
-      .def("set_mode_cpu",     &CaffeNet::set_mode_cpu)
-      .def("set_mode_gpu",     &CaffeNet::set_mode_gpu)
-      .def("set_phase_train",  &CaffeNet::set_phase_train)
-      .def("set_phase_test",   &CaffeNet::set_phase_test)
-      .def("set_device",       &CaffeNet::set_device)
-      // rename blobs here since the pycaffe.py wrapper will replace it
-      .add_property("_blobs",  &CaffeNet::blobs)
-      .add_property("layers",  &CaffeNet::layers);
+      .def("Forward",           &CaffeNet::Forward)
+      .def("ForwardPrefilled",  &CaffeNet::ForwardPrefilled)
+      .def("Backward",          &CaffeNet::Backward)
+      .def("set_mode_cpu",      &CaffeNet::set_mode_cpu)
+      .def("set_mode_gpu",      &CaffeNet::set_mode_gpu)
+      .def("set_phase_train",   &CaffeNet::set_phase_train)
+      .def("set_phase_test",    &CaffeNet::set_phase_test)
+      .def("set_device",        &CaffeNet::set_device)
+      .add_property("_blobs",   &CaffeNet::blobs)
+      .add_property("layers",   &CaffeNet::layers)
+      .def("_set_input_arrays", &CaffeNet::set_input_arrays);
 
   boost::python::class_<CaffeBlobWrap>(
       "Blob", boost::python::no_init)
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 863d315b360..05187e98b31 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -3,8 +3,10 @@
 interface.
 """
 
-from ._caffe import Net, SGDSolver
 from collections import OrderedDict
+import numpy as np
+
+from ._caffe import Net, SGDSolver
 
 # we directly update methods from Net here (rather than using composition or
 # inheritance) so that nets created by caffe (e.g., by SGDSolver) will
@@ -31,3 +33,11 @@ def _Net_params(self):
                         if len(lr.blobs) > 0])
 
 Net.params = _Net_params
+
+def _Net_set_input_arrays(self, data, labels):
+    if labels.ndim == 1:
+        labels = np.ascontiguousarray(labels[:, np.newaxis, np.newaxis,
+                                              np.newaxis])
+    return self._set_input_arrays(data, labels)
+
+Net.set_input_arrays = _Net_set_input_arrays
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index f1fe98df4a6..444e9cf4009 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -48,6 +48,12 @@ const Dtype* Blob<Dtype>::cpu_data() const {
   return (const Dtype*)data_->cpu_data();
 }
 
+template <typename Dtype>
+void Blob<Dtype>::set_cpu_data(Dtype* data) {
+  CHECK(data);
+  data_->set_cpu_data(data);
+}
+
 template <typename Dtype>
 const Dtype* Blob<Dtype>::gpu_data() const {
   CHECK(data_);
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index d586924aebe..2991c81f559 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -56,6 +56,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new InnerProductLayer<Dtype>(param);
   case LayerParameter_LayerType_LRN:
     return new LRNLayer<Dtype>(param);
+  case LayerParameter_LayerType_MEMORY_DATA:
+    return new MemoryDataLayer<Dtype>(param);
   case LayerParameter_LayerType_MULTINOMIAL_LOGISTIC_LOSS:
     return new MultinomialLogisticLossLayer<Dtype>(param);
   case LayerParameter_LayerType_POOLING:
diff --git a/src/caffe/layers/memory_data_layer.cpp b/src/caffe/layers/memory_data_layer.cpp
new file mode 100644
index 00000000000..60bce27b8c9
--- /dev/null
+++ b/src/caffe/layers/memory_data_layer.cpp
@@ -0,0 +1,51 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void MemoryDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 0) << "Memory Data Layer takes no blobs as input.";
+  CHECK_EQ(top->size(), 2) << "Memory Data Layer takes two blobs as output.";
+  batch_size_ = this->layer_param_.memory_data_param().batch_size();
+  datum_channels_ = this->layer_param_.memory_data_param().channels();
+  datum_height_ = this->layer_param_.memory_data_param().height();
+  datum_width_ = this->layer_param_.memory_data_param().width();
+  datum_size_ = datum_channels_ * datum_height_ * datum_width_;
+  CHECK_GT(batch_size_ * datum_size_, 0) << "batch_size, channels, height,"
+    " and width must be specified and positive in memory_data_param";
+  (*top)[0]->Reshape(batch_size_, datum_channels_, datum_height_, datum_width_);
+  (*top)[1]->Reshape(batch_size_, 1, 1, 1);
+  data_ = NULL;
+  labels_ = NULL;
+}
+
+template <typename Dtype>
+void MemoryDataLayer<Dtype>::Reset(Dtype* data, Dtype* labels, int n) {
+  CHECK(data);
+  CHECK(labels);
+  CHECK_EQ(n % batch_size_, 0) << "n must be a multiple of batch size";
+  data_ = data;
+  labels_ = labels;
+  n_ = n;
+  pos_ = 0;
+}
+
+template <typename Dtype>
+Dtype MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK(data_) << "MemoryDataLayer needs to be initialized by calling Reset";
+  (*top)[0]->set_cpu_data(data_ + pos_ * datum_size_);
+  (*top)[1]->set_cpu_data(labels_ + pos_);
+  pos_ = (pos_ + batch_size_) % n_;
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(MemoryDataLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index e04e42c40a3..9fba81779f2 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -99,7 +99,7 @@
 
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available ID: 22 (last added: power_param)
+// LayerParameter next available ID: 23 (last added: memory_data_param)
 message LayerParameter {
   repeated string bottom = 2; // the name of the bottom blobs
   repeated string top = 3; // the name of the top blobs
@@ -110,7 +110,7 @@ message LayerParameter {
   // line above the enum. Update the next available ID when you add a new
   // LayerType.
   //
-  // LayerType next available ID: 29 (last added: HINGE_LOSS)
+  // LayerType next available ID: 30 (last added: MEMORY_DATA)
   enum LayerType {
     // "NONE" layer type is 0th enum element so that we don't cause confusion
     // by defaulting to an existent LayerType (instead, should usually error if
@@ -133,6 +133,7 @@ message LayerParameter {
     INFOGAIN_LOSS = 13;
     INNER_PRODUCT = 14;
     LRN = 15;
+    MEMORY_DATA = 29;
     MULTINOMIAL_LOGISTIC_LOSS = 16;
     POOLING = 17;
     POWER = 26;
@@ -166,6 +167,7 @@ message LayerParameter {
   optional InfogainLossParameter infogain_loss_param = 16;
   optional InnerProductParameter inner_product_param = 17;
   optional LRNParameter lrn_param = 18;
+  optional MemoryDataParameter memory_data_param = 22;
   optional PoolingParameter pooling_param = 19;
   optional PowerParameter power_param = 21;
   optional WindowDataParameter window_data_param = 20;
@@ -289,6 +291,14 @@ message LRNParameter {
   optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS];
 }
 
+// Message that stores parameters used by MemoryDataLayer
+message MemoryDataParameter {
+  optional uint32 batch_size = 1;
+  optional uint32 channels = 2;
+  optional uint32 height = 3;
+  optional uint32 width = 4;
+}
+
 // Message that stores parameters used by PoolingLayer
 message PoolingParameter {
   enum PoolMethod {
diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp
index c33f3e60c0b..fec37d6e9ec 100644
--- a/src/caffe/syncedmem.cpp
+++ b/src/caffe/syncedmem.cpp
@@ -10,7 +10,7 @@
 namespace caffe {
 
 SyncedMemory::~SyncedMemory() {
-  if (cpu_ptr_) {
+  if (cpu_ptr_ && own_cpu_data_) {
     CaffeFreeHost(cpu_ptr_);
   }
 
@@ -25,10 +25,12 @@ inline void SyncedMemory::to_cpu() {
     CaffeMallocHost(&cpu_ptr_, size_);
     memset(cpu_ptr_, 0, size_);
     head_ = HEAD_AT_CPU;
+    own_cpu_data_ = true;
     break;
   case HEAD_AT_GPU:
     if (cpu_ptr_ == NULL) {
      CaffeMallocHost(&cpu_ptr_, size_);
+      own_cpu_data_ = true;
     }
     CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDeviceToHost));
     head_ = SYNCED;
@@ -64,6 +66,16 @@ const void* SyncedMemory::cpu_data() {
   return (const void*)cpu_ptr_;
 }
 
+void SyncedMemory::set_cpu_data(void* data) {
+  CHECK(data);
+  if (own_cpu_data_) {
+    CaffeFreeHost(cpu_ptr_);
+  }
+  cpu_ptr_ = data;
+  head_ = HEAD_AT_CPU;
+  own_cpu_data_ = false;
+}
+
 const void* SyncedMemory::gpu_data() {
   to_gpu();
   return (const void*)gpu_ptr_;
diff --git a/src/caffe/test/test_memory_data_layer.cpp b/src/caffe/test/test_memory_data_layer.cpp
new file mode 100644
index 00000000000..15f01bd41e3
--- /dev/null
+++ b/src/caffe/test/test_memory_data_layer.cpp
@@ -0,0 +1,108 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class MemoryDataLayerTest : public ::testing::Test {
+ protected:
+  MemoryDataLayerTest()
+      : data_blob_(new Blob<Dtype>()),
+        label_blob_(new Blob<Dtype>()),
+        data_(new Blob<Dtype>()), labels_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    batch_size_ = 8;
+    batches_ = 12;
+    channels_ = 4;
+    height_ = 7;
+    width_ = 11;
+    blob_top_vec_.push_back(data_blob_);
+    blob_top_vec_.push_back(label_blob_);
+    // pick random input data
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    data_->Reshape(batches_ * batch_size_, channels_, height_, width_);
+    labels_->Reshape(batches_ * batch_size_, 1, 1, 1);
+    filler.Fill(this->data_);
+    filler.Fill(this->labels_);
+  }
+
+  virtual ~MemoryDataLayerTest() {
+    delete data_blob_;
+    delete label_blob_;
+    delete data_;
+    delete labels_;
+  }
+  int batch_size_;
+  int batches_;
+  int channels_;
+  int height_;
+  int width_;
+  // we don't really need blobs for the input data, but it makes it
+  // easier to call Filler
+  Blob<Dtype>* const data_;
+  Blob<Dtype>* const labels_;
+  // blobs for the top of MemoryDataLayer
+  Blob<Dtype>* const data_blob_;
+  Blob<Dtype>* const label_blob_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(MemoryDataLayerTest, Dtypes);
+
+TYPED_TEST(MemoryDataLayerTest, TestSetup) {
+  LayerParameter layer_param;
+  MemoryDataParameter* md_param = layer_param.mutable_memory_data_param();
+  md_param->set_batch_size(this->batch_size_);
+  md_param->set_channels(this->channels_);
+  md_param->set_height(this->height_);
+  md_param->set_width(this->width_);
+  shared_ptr<MemoryDataLayer<TypeParam> > layer(
+      new MemoryDataLayer<TypeParam>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->data_blob_->num(), this->batch_size_);
+  EXPECT_EQ(this->data_blob_->channels(), this->channels_);
+  EXPECT_EQ(this->data_blob_->height(), this->height_);
+  EXPECT_EQ(this->data_blob_->width(), this->width_);
+  EXPECT_EQ(this->label_blob_->num(), this->batch_size_);
+  EXPECT_EQ(this->label_blob_->channels(), 1);
+  EXPECT_EQ(this->label_blob_->height(), 1);
+  EXPECT_EQ(this->label_blob_->width(), 1);
+}
+
+// run through a few batches and check that the right data appears
+TYPED_TEST(MemoryDataLayerTest, TestForward) {
+  LayerParameter layer_param;
+  MemoryDataParameter* md_param = layer_param.mutable_memory_data_param();
+  md_param->set_batch_size(this->batch_size_);
+  md_param->set_channels(this->channels_);
+  md_param->set_height(this->height_);
+  md_param->set_width(this->width_);
+  shared_ptr<MemoryDataLayer<TypeParam> > layer(
+      new MemoryDataLayer<TypeParam>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer->Reset(this->data_->mutable_cpu_data(),
+      this->labels_->mutable_cpu_data(), this->data_->num());
+  for (int i = 0; i < this->batches_ * 6; ++i) {
+    int batch_num = i % this->batches_;
+    layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+    for (int j = 0; j < this->data_blob_->count(); ++j) {
+      EXPECT_EQ(this->data_blob_->cpu_data()[j],
+          this->data_->cpu_data()[
+              this->data_->offset(1) * this->batch_size_ * batch_num + j]);
+    }
+    for (int j = 0; j < this->label_blob_->count(); ++j) {
+      EXPECT_EQ(this->label_blob_->cpu_data()[j],
+          this->labels_->cpu_data()[this->batch_size_ * batch_num + j]);
+    }
+  }
+}
+
+}  // namespace caffe
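Usage sketch (editorial note, not part of the patch): the snippet below shows how the new in-memory input path is meant to be driven from pycaffe. It assumes the caffe package re-exports the wrapped Net, that the net's first layer is a MEMORY_DATA layer whose memory_data_param matches the array shapes, and that the file names and sizes are placeholders.

    import numpy as np
    import caffe

    # Placeholder shapes: channels/height/width must match the first layer's
    # memory_data_param, and N must be a multiple of its batch_size.
    N = 64
    data = np.random.randn(N, 4, 7, 11).astype(np.float32)  # float32, C contiguous
    labels = np.arange(N, dtype=np.float32)  # 1-D labels; pycaffe reshapes to N x 1 x 1 x 1

    # Hypothetical file names; the prototxt's first layer must be MEMORY_DATA.
    net = caffe.Net('memory_data_net.prototxt', 'pretrained.caffemodel')
    net.set_input_arrays(data, labels)  # validated, then handed to MemoryDataLayer::Reset
    net.ForwardPrefilled()              # consumes one batch; repeated calls cycle through the arrays

Because Forward_cpu only re-points the top blobs at the caller's buffers (no copy, no mean subtraction or scaling), the arrays must already be preprocessed and must stay alive while the net uses them; the bindings store references to them (input_data_, input_labels_) for exactly that reason.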