Merge pull request BVLC#294 from longjon/memory-data-layer
Add a layer for in-memory data, and expose it to Python
shelhamer committed May 2, 2014
2 parents cd6bfaf + 11d223d commit 183e9e3
Showing 11 changed files with 327 additions and 19 deletions.
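
Before the per-file diffs, a minimal sketch of how the pieces added here are meant to fit together from Python. It is only illustrative: the prototxt contents, file name, shapes, and batch size are assumptions, and the only calls taken from this commit are the Net constructor, Net.set_input_arrays, and Net.ForwardPrefilled.

import numpy as np
from caffe.pycaffe import Net  # Net as patched by the pycaffe.py changes below

# Hypothetical net whose first layer is the new MEMORY_DATA layer; the
# memory_data_param dimensions must match the arrays handed over later.
net_proto = """
name: "memory_net"
layers {
  name: "data"
  type: MEMORY_DATA
  top: "data"
  top: "label"
  memory_data_param { batch_size: 32 channels: 3 height: 32 width: 32 }
}
# ... remaining layers of the model ...
"""
with open('memory_net.prototxt', 'w') as f:
    f.write(net_proto)

net = Net('memory_net.prototxt')

# Data must be float32, C contiguous, and 4-d (N, channels, height, width),
# with N a multiple of batch_size; labels may be a 1-d array of length N.
data = np.random.rand(64, 3, 32, 32).astype(np.float32)
labels = np.zeros(64, dtype=np.float32)

net.set_input_arrays(data, labels)
net.ForwardPrefilled()  # each call consumes one batch, wrapping around at N
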
1 change: 1 addition & 0 deletions include/caffe/blob.hpp
@@ -63,6 +63,7 @@ class Blob {
}

const Dtype* cpu_data() const;
void set_cpu_data(Dtype* data);
const Dtype* gpu_data() const;
const Dtype* cpu_diff() const;
const Dtype* gpu_diff() const;
9 changes: 7 additions & 2 deletions include/caffe/syncedmem.hpp
@@ -35,24 +35,29 @@ inline void CaffeFreeHost(void* ptr) {
class SyncedMemory {
public:
SyncedMemory()
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED) {}
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED),
own_cpu_data_(false) {}
explicit SyncedMemory(size_t size)
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED) {}
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED),
own_cpu_data_(false) {}
~SyncedMemory();
const void* cpu_data();
void set_cpu_data(void* data);
const void* gpu_data();
void* mutable_cpu_data();
void* mutable_gpu_data();
enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED };
SyncedHead head() { return head_; }
size_t size() { return size_; }

private:
void to_cpu();
void to_gpu();
void* cpu_ptr_;
void* gpu_ptr_;
size_t size_;
SyncedHead head_;
bool own_cpu_data_;

DISABLE_COPY_AND_ASSIGN(SyncedMemory);
}; // class SyncedMemory
34 changes: 34 additions & 0 deletions include/caffe/vision_layers.hpp
@@ -624,6 +624,40 @@ class LRNLayer : public Layer<Dtype> {
vector<Blob<Dtype>*> product_bottom_vec_;
};

template <typename Dtype>
class MemoryDataLayer : public Layer<Dtype> {
public:
explicit MemoryDataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
// Reset should accept const pointers, but can't, because the memory
// will be given to Blob, which is mutable
void Reset(Dtype* data, Dtype* label, int n);
int datum_channels() { return datum_channels_; }
int datum_height() { return datum_height_; }
int datum_width() { return datum_width_; }
int batch_size() { return batch_size_; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }

Dtype* data_;
Dtype* labels_;
int datum_channels_;
int datum_height_;
int datum_width_;
int datum_size_;
int batch_size_;
int n_;
int pos_;
};

template <typename Dtype>
class MultinomialLogisticLossLayer : public Layer<Dtype> {
public:
95 changes: 82 additions & 13 deletions python/caffe/_caffe.cpp
@@ -158,6 +158,8 @@ struct CaffeNet {

virtual ~CaffeNet() {}

// this function is mostly redundant with the one below, but should go away
// with new pycaffe
inline void check_array_against_blob(
PyArrayObject* arr, Blob<float>* blob) {
CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS);
@@ -170,6 +172,29 @@
CHECK_EQ(dims[3], blob->width());
}

// generate Python exceptions for badly shaped or discontiguous arrays
inline void check_contiguous_array(PyArrayObject* arr, string name,
int channels, int height, int width) {
if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
throw std::runtime_error(name + " must be C contiguous");
}
if (PyArray_NDIM(arr) != 4) {
throw std::runtime_error(name + " must be 4-d");
}
if (PyArray_TYPE(arr) != NPY_FLOAT32) {
throw std::runtime_error(name + " must be float32");
}
if (PyArray_DIMS(arr)[1] != channels) {
throw std::runtime_error(name + " has wrong number of channels");
}
if (PyArray_DIMS(arr)[2] != height) {
throw std::runtime_error(name + " has wrong height");
}
if (PyArray_DIMS(arr)[3] != width) {
throw std::runtime_error(name + " has wrong width");
}
}

// The actual forward function. It takes in a python list of numpy arrays as
// input and a python list of numpy arrays as output. The input and output
// should all have correct shapes, are single-precision, and
@@ -267,6 +292,41 @@ struct CaffeNet {
net_->ForwardPrefilled();
}

void set_input_arrays(object data_obj, object labels_obj) {
// check that this network has an input MemoryDataLayer
shared_ptr<MemoryDataLayer<float> > md_layer =
boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]);
if (!md_layer) {
throw std::runtime_error("set_input_arrays may only be called if the"
" first layer is a MemoryDataLayer");
}

// check that we were passed appropriately-sized contiguous memory
PyArrayObject* data_arr =
reinterpret_cast<PyArrayObject*>(data_obj.ptr());
PyArrayObject* labels_arr =
reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
check_contiguous_array(data_arr, "data array", md_layer->datum_channels(),
md_layer->datum_height(), md_layer->datum_width());
check_contiguous_array(labels_arr, "labels array", 1, 1, 1);
if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
throw std::runtime_error("data and labels must have the same first"
" dimension");
}
if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
throw std::runtime_error("first dimensions of input arrays must be a"
" multiple of batch size");
}

// hold references
input_data_ = data_obj;
input_labels_ = labels_obj;

md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)),
static_cast<float*>(PyArray_DATA(labels_arr)),
PyArray_DIMS(data_arr)[0]);
}

// The caffe::Caffe utility functions.
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
@@ -292,6 +352,9 @@

// The pointer to the internal caffe::Net instant.
shared_ptr<Net<float> > net_;
// if taking input from an ndarray, we need to hold references
object input_data_;
object input_labels_;
};

class CaffeSGDSolver {
@@ -301,36 +364,42 @@
// exception if param_file can't be opened
CheckFile(param_file);
solver_.reset(new SGDSolver<float>(param_file));
// we need to explicitly store the net wrapper, rather than constructing
// it on the fly, so that it can hold references to Python objects
net_.reset(new CaffeNet(solver_->net()));
}

CaffeNet net() { return CaffeNet(solver_->net()); }
shared_ptr<CaffeNet> net() { return net_; }
void Solve() { return solver_->Solve(); }
void SolveResume(const string& resume_file) {
CheckFile(resume_file);
return solver_->Solve(resume_file);
}

protected:
shared_ptr<CaffeNet> net_;
shared_ptr<SGDSolver<float> > solver_;
};


// The boost python module definition.
BOOST_PYTHON_MODULE(_caffe) {
boost::python::class_<CaffeNet>(
// below, we prepend an underscore to methods that will be replaced
// in Python
boost::python::class_<CaffeNet, shared_ptr<CaffeNet> >(
"Net", boost::python::init<string, string>())
.def(boost::python::init<string>())
.def("Forward", &CaffeNet::Forward)
.def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
.def("Backward", &CaffeNet::Backward)
.def("set_mode_cpu", &CaffeNet::set_mode_cpu)
.def("set_mode_gpu", &CaffeNet::set_mode_gpu)
.def("set_phase_train", &CaffeNet::set_phase_train)
.def("set_phase_test", &CaffeNet::set_phase_test)
.def("set_device", &CaffeNet::set_device)
// rename blobs here since the pycaffe.py wrapper will replace it
.add_property("_blobs", &CaffeNet::blobs)
.add_property("layers", &CaffeNet::layers);
.def("Forward", &CaffeNet::Forward)
.def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
.def("Backward", &CaffeNet::Backward)
.def("set_mode_cpu", &CaffeNet::set_mode_cpu)
.def("set_mode_gpu", &CaffeNet::set_mode_gpu)
.def("set_phase_train", &CaffeNet::set_phase_train)
.def("set_phase_test", &CaffeNet::set_phase_test)
.def("set_device", &CaffeNet::set_device)
.add_property("_blobs", &CaffeNet::blobs)
.add_property("layers", &CaffeNet::layers)
.def("_set_input_arrays", &CaffeNet::set_input_arrays);

boost::python::class_<CaffeBlob, CaffeBlobWrap>(
"Blob", boost::python::no_init)
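
The checks above turn bad inputs into Python RuntimeErrors rather than CHECK failures, so the caller is expected to hand over float32, C contiguous, 4-d arrays. A small hypothetical helper (not part of this commit) showing the coercion that satisfies check_contiguous_array:

import numpy as np

def as_memory_data_input(arr):
    """Coerce an array to what _set_input_arrays accepts:
    float32, C contiguous, 4-d (N, channels, height, width)."""
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    if arr.ndim != 4:
        raise ValueError("expected a 4-d array, got %d dims" % arr.ndim)
    return arr

# e.g. a float64, Fortran-ordered batch becomes an acceptable input
raw = np.asfortranarray(np.random.rand(64, 3, 32, 32))
data = as_memory_data_input(raw)
assert data.dtype == np.float32 and data.flags['C_CONTIGUOUS']
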
12 changes: 11 additions & 1 deletion python/caffe/pycaffe.py
@@ -3,8 +3,10 @@
interface.
"""

from ._caffe import Net, SGDSolver
from collections import OrderedDict
import numpy as np

from ._caffe import Net, SGDSolver

# we directly update methods from Net here (rather than using composition or
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
@@ -31,3 +33,11 @@ def _Net_params(self):
if len(lr.blobs) > 0])

Net.params = _Net_params

def _Net_set_input_arrays(self, data, labels):
if labels.ndim == 1:
labels = np.ascontiguousarray(labels[:, np.newaxis, np.newaxis,
np.newaxis])
return self._set_input_arrays(data, labels)

Net.set_input_arrays = _Net_set_input_arrays
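
The wrapper's only job is to let callers pass labels as a flat vector; it pads them to the 4-d shape the C++ check expects before delegating. Roughly what happens to a 1-d label array (values arbitrary):

import numpy as np

labels = np.arange(6, dtype=np.float32)                  # shape (6,)
padded = np.ascontiguousarray(labels[:, np.newaxis, np.newaxis, np.newaxis])
print(padded.shape)                                      # (6, 1, 1, 1)
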
6 changes: 6 additions & 0 deletions src/caffe/blob.cpp
@@ -48,6 +48,12 @@ const Dtype* Blob<Dtype>::cpu_data() const {
return (const Dtype*)data_->cpu_data();
}

template <typename Dtype>
void Blob<Dtype>::set_cpu_data(Dtype* data) {
CHECK(data);
data_->set_cpu_data(data);
}

template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_data() const {
CHECK(data_);
2 changes: 2 additions & 0 deletions src/caffe/layer_factory.cpp
@@ -56,6 +56,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return new InnerProductLayer<Dtype>(param);
case LayerParameter_LayerType_LRN:
return new LRNLayer<Dtype>(param);
case LayerParameter_LayerType_MEMORY_DATA:
return new MemoryDataLayer<Dtype>(param);
case LayerParameter_LayerType_MULTINOMIAL_LOGISTIC_LOSS:
return new MultinomialLogisticLossLayer<Dtype>(param);
case LayerParameter_LayerType_POOLING:
51 changes: 51 additions & 0 deletions src/caffe/layers/memory_data_layer.cpp
@@ -0,0 +1,51 @@
// Copyright 2014 BVLC and contributors.

#include <vector>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {

template <typename Dtype>
void MemoryDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_EQ(bottom.size(), 0) << "Memory Data Layer takes no blobs as input.";
CHECK_EQ(top->size(), 2) << "Memory Data Layer takes two blobs as output.";
batch_size_ = this->layer_param_.memory_data_param().batch_size();
datum_channels_ = this->layer_param_.memory_data_param().channels();
datum_height_ = this->layer_param_.memory_data_param().height();
datum_width_ = this->layer_param_.memory_data_param().width();
datum_size_ = datum_channels_ * datum_height_ * datum_width_;
CHECK_GT(batch_size_ * datum_size_, 0) << "batch_size, channels, height,"
" and width must be specified and positive in memory_data_param";
(*top)[0]->Reshape(batch_size_, datum_channels_, datum_height_, datum_width_);
(*top)[1]->Reshape(batch_size_, 1, 1, 1);
data_ = NULL;
labels_ = NULL;
}

template <typename Dtype>
void MemoryDataLayer<Dtype>::Reset(Dtype* data, Dtype* labels, int n) {
CHECK(data);
CHECK(labels);
CHECK_EQ(n % batch_size_, 0) << "n must be a multiple of batch size";
data_ = data;
labels_ = labels;
n_ = n;
pos_ = 0;
}

template <typename Dtype>
Dtype MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK(data_) << "MemoryDataLayer needs to be initalized by calling Reset";
(*top)[0]->set_cpu_data(data_ + pos_ * datum_size_);
(*top)[1]->set_cpu_data(labels_ + pos_);
pos_ = (pos_ + batch_size_) % n_;
return Dtype(0.);
}

INSTANTIATE_CLASS(MemoryDataLayer);

} // namespace caffe
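
Forward_cpu never copies: it repoints the top blobs into the caller's buffer and advances a cursor, wrapping around once all n records have been served. A small model of that cursor arithmetic (numbers are illustrative):

def batch_offsets(n, batch_size, num_forwards):
    """Yield the record offset used by each successive forward pass."""
    pos = 0
    for _ in range(num_forwards):
        yield pos                     # top data points at data_ + pos * datum_size_
        pos = (pos + batch_size) % n  # same update as Forward_cpu

print(list(batch_offsets(n=6, batch_size=2, num_forwards=5)))  # [0, 2, 4, 0, 2]
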
14 changes: 12 additions & 2 deletions src/caffe/proto/caffe.proto
@@ -99,7 +99,7 @@ message SolverState {

// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available ID: 22 (last added: power_param)
// LayerParameter next available ID: 23 (last added: memory_data_param)
message LayerParameter {
repeated string bottom = 2; // the name of the bottom blobs
repeated string top = 3; // the name of the top blobs
@@ -110,7 +110,7 @@ message LayerParameter {
// line above the enum. Update the next available ID when you add a new
// LayerType.
//
// LayerType next available ID: 29 (last added: HINGE_LOSS)
// LayerType next available ID: 30 (last added: MEMORY_DATA)
enum LayerType {
// "NONE" layer type is 0th enum element so that we don't cause confusion
// by defaulting to an existent LayerType (instead, should usually error if
@@ -133,6 +133,7 @@
INFOGAIN_LOSS = 13;
INNER_PRODUCT = 14;
LRN = 15;
MEMORY_DATA = 29;
MULTINOMIAL_LOGISTIC_LOSS = 16;
POOLING = 17;
POWER = 26;
@@ -166,6 +167,7 @@
optional InfogainLossParameter infogain_loss_param = 16;
optional InnerProductParameter inner_product_param = 17;
optional LRNParameter lrn_param = 18;
optional MemoryDataParameter memory_data_param = 22;
optional PoolingParameter pooling_param = 19;
optional PowerParameter power_param = 21;
optional WindowDataParameter window_data_param = 20;
@@ -289,6 +291,14 @@ message LRNParameter {
optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS];
}

// Message that stores parameters used by MemoryDataLayer
message MemoryDataParameter {
optional uint32 batch_size = 1;
optional uint32 channels = 2;
optional uint32 height = 3;
optional uint32 width = 4;
}

// Message that stores parameters used by PoolingLayer
message PoolingParameter {
enum PoolMethod {
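
For completeness, a sketch of the new message in use, assuming caffe.proto has been compiled with protoc to a Python module (named caffe_pb2 here by convention; this commit does not set that up):

from google.protobuf import text_format
import caffe_pb2  # assumed output of `protoc --python_out=. caffe.proto`

layer = caffe_pb2.LayerParameter()
text_format.Merge(
    'type: MEMORY_DATA '
    'memory_data_param { batch_size: 32 channels: 3 height: 32 width: 32 }',
    layer)
print(layer.memory_data_param.batch_size)  # 32
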