-
Notifications
You must be signed in to change notification settings - Fork 18.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve python wrapper #311
Changes from 12 commits
51f276e
47ec9ac
8da2a32
872ddf3
56ca978
96cd02d
0e5a5cf
9d4324e
ac5e6fa
af0b857
8af33e8
1b23680
a3b307a
a7f6750
459c8c1
025c64e
5d584c2
5102413
37123a5
738c875
50d0b6d
6b85fd0
bf4d726
2fc32d5
111df0e
02ecf1d
8830dc5
42bf2d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
// Copyright 2014 BVLC and contributors. | ||
// pycaffe provides a wrapper of the caffe::Net class as well as some | ||
// caffe::Caffe functions so that one could easily call it from Python. | ||
// caffe::Caffe functions so that one could easily call it from python. | ||
// Note that for python, we will simply use float as the data type. | ||
|
||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION | ||
|
@@ -33,7 +33,7 @@ using boost::python::handle; | |
using boost::python::vector_indexing_suite; | ||
|
||
// for convenience, check that input files can be opened, and raise an | ||
// exception that boost will send to Python if not (caffe could still crash | ||
// exception that boost will send to python if not (caffe could still crash | ||
// later if the input files are disturbed before they are actually used, but | ||
// this saves frustration in most cases) | ||
static void CheckFile(const string& filename) { | ||
|
@@ -46,7 +46,7 @@ static void CheckFile(const string& filename) { | |
} | ||
|
||
// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass | ||
// to Python | ||
// to python | ||
class CaffeBlob { | ||
public: | ||
CaffeBlob(const shared_ptr<Blob<float> > &blob, const string& name) | ||
|
@@ -70,9 +70,9 @@ class CaffeBlob { | |
}; | ||
|
||
|
||
// we need another wrapper (used as boost::python's HeldType) that receives a | ||
// self PyObject * which we can use as ndarray.base, so that data/diff memory | ||
// is not freed while still being used in Python | ||
// We need another wrapper (used as boost::python's HeldType) that receives a | ||
// self PyObject * which we can use as ndarray.base, so that data/diff memory | ||
// is not freed while still being used in python. | ||
class CaffeBlobWrap : public CaffeBlob { | ||
public: | ||
CaffeBlobWrap(PyObject *p, const CaffeBlob &blob) | ||
|
@@ -142,8 +142,9 @@ struct CaffeNet { | |
} | ||
|
||
CaffeNet(string param_file, string pretrained_param_file) { | ||
Init(param_file); | ||
CheckFile(param_file); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this right? This looks like part of a previous commit being accidentally reverted. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for catching that — fixed in a3b307a. |
||
CheckFile(pretrained_param_file); | ||
net_.reset(new Net<float>(param_file)); | ||
net_->CopyTrainedLayersFrom(pretrained_param_file); | ||
} | ||
|
||
|
@@ -158,18 +159,15 @@ struct CaffeNet { | |
|
||
virtual ~CaffeNet() {} | ||
|
||
// this function is mostly redundant with the one below, but should go away | ||
// with new pycaffe | ||
// Check that an array is acceptable for blob assignment | ||
// as described in the preface to Forward(). | ||
inline void check_array_against_blob( | ||
PyArrayObject* arr, Blob<float>* blob) { | ||
CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS); | ||
CHECK_EQ(PyArray_NDIM(arr), 4); | ||
CHECK_EQ(PyArray_ITEMSIZE(arr), 4); | ||
npy_intp* dims = PyArray_DIMS(arr); | ||
CHECK_EQ(dims[0], blob->num()); | ||
CHECK_EQ(dims[1], blob->channels()); | ||
CHECK_EQ(dims[2], blob->height()); | ||
CHECK_EQ(dims[3], blob->width()); | ||
PyArrayObject* arr, Blob<float>* blob, string name) { | ||
check_contiguous_array(arr, name, blob->channels(), blob->height(), | ||
blob->width()); | ||
if (PyArray_DIMS(arr)[0] != blob->num()) { | ||
throw std::runtime_error(name + " has wrong batch size"); | ||
} | ||
} | ||
|
||
// generate Python exceptions for badly shaped or discontiguous arrays | ||
|
@@ -197,8 +195,7 @@ struct CaffeNet { | |
|
||
// The actual forward function. It takes in a python list of numpy arrays as | ||
// input and a python list of numpy arrays as output. The input and output | ||
// should all have correct shapes, are single-precision, and | ||
// c contiguous. | ||
// should all have correct shapes, be single-precision, and be C-contiguous. | ||
void Forward(list bottom, list top) { | ||
vector<Blob<float>*>& input_blobs = net_->input_blobs(); | ||
CHECK_EQ(len(bottom), input_blobs.size()); | ||
|
@@ -207,7 +204,8 @@ struct CaffeNet { | |
for (int i = 0; i < input_blobs.size(); ++i) { | ||
object elem = bottom[i]; | ||
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr()); | ||
check_array_against_blob(arr, input_blobs[i]); | ||
check_array_against_blob(arr, input_blobs[i], | ||
net_->blob_names()[net_->input_blob_indices()[i]]); | ||
switch (Caffe::mode()) { | ||
case Caffe::CPU: | ||
memcpy(input_blobs[i]->mutable_cpu_data(), PyArray_DATA(arr), | ||
|
@@ -227,7 +225,8 @@ struct CaffeNet { | |
for (int i = 0; i < output_blobs.size(); ++i) { | ||
object elem = top[i]; | ||
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr()); | ||
check_array_against_blob(arr, output_blobs[i]); | ||
check_array_against_blob(arr, output_blobs[i], | ||
net_->blob_names()[net_->input_blob_indices()[i]]); | ||
switch (Caffe::mode()) { | ||
case Caffe::CPU: | ||
memcpy(PyArray_DATA(arr), output_blobs[i]->cpu_data(), | ||
|
@@ -252,7 +251,8 @@ struct CaffeNet { | |
for (int i = 0; i < output_blobs.size(); ++i) { | ||
object elem = top_diff[i]; | ||
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr()); | ||
check_array_against_blob(arr, output_blobs[i]); | ||
check_array_against_blob(arr, output_blobs[i], | ||
net_->blob_names()[net_->input_blob_indices()[i]]); | ||
switch (Caffe::mode()) { | ||
case Caffe::CPU: | ||
memcpy(output_blobs[i]->mutable_cpu_diff(), PyArray_DATA(arr), | ||
|
@@ -272,7 +272,8 @@ struct CaffeNet { | |
for (int i = 0; i < input_blobs.size(); ++i) { | ||
object elem = bottom_diff[i]; | ||
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr()); | ||
check_array_against_blob(arr, input_blobs[i]); | ||
check_array_against_blob(arr, input_blobs[i], | ||
net_->blob_names()[net_->input_blob_indices()[i]]); | ||
switch (Caffe::mode()) { | ||
case Caffe::CPU: | ||
memcpy(PyArray_DATA(arr), input_blobs[i]->cpu_diff(), | ||
|
@@ -292,6 +293,10 @@ struct CaffeNet { | |
net_->ForwardPrefilled(); | ||
} | ||
|
||
void BackwardPrefilled() { | ||
net_->Backward(); | ||
} | ||
|
||
void set_input_arrays(object data_obj, object labels_obj) { | ||
// check that this network has an input MemoryDataLayer | ||
shared_ptr<MemoryDataLayer<float> > md_layer = | ||
|
@@ -350,6 +355,24 @@ struct CaffeNet { | |
return result; | ||
} | ||
|
||
list inputs() { | ||
list input_blob_names; | ||
for (vector<int>::iterator it = net_->input_blob_indices().begin(); | ||
it != net_->input_blob_indices().end(); ++it) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've been wondering why we don't use iterators since #112 (comment). Every other caffe loop over a vector uses an index. Does anyone know a good reason for this? In either case, it would be nice to be consistent throughout. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
input_blob_names.append(net_->blob_names()[*it]); | ||
} | ||
return input_blob_names; | ||
} | ||
|
||
list outputs() { | ||
list output_blob_names; | ||
for (vector<int>::iterator it = net_->output_blob_indices().begin(); | ||
it != net_->output_blob_indices().end(); ++it) { | ||
output_blob_names.append(net_->blob_names()[*it]); | ||
} | ||
return output_blob_names; | ||
} | ||
|
||
// The pointer to the internal caffe::Net instant. | ||
shared_ptr<Net<float> > net_; | ||
// if taking input from an ndarray, we need to hold references | ||
|
@@ -392,13 +415,16 @@ BOOST_PYTHON_MODULE(_caffe) { | |
.def("Forward", &CaffeNet::Forward) | ||
.def("ForwardPrefilled", &CaffeNet::ForwardPrefilled) | ||
.def("Backward", &CaffeNet::Backward) | ||
.def("BackwardPrefilled", &CaffeNet::BackwardPrefilled) | ||
.def("set_mode_cpu", &CaffeNet::set_mode_cpu) | ||
.def("set_mode_gpu", &CaffeNet::set_mode_gpu) | ||
.def("set_phase_train", &CaffeNet::set_phase_train) | ||
.def("set_phase_test", &CaffeNet::set_phase_test) | ||
.def("set_device", &CaffeNet::set_device) | ||
.add_property("_blobs", &CaffeNet::blobs) | ||
.add_property("layers", &CaffeNet::layers) | ||
.add_property("inputs", &CaffeNet::inputs) | ||
.add_property("outputs", &CaffeNet::outputs) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Extra spaces here? |
||
.def("_set_input_arrays", &CaffeNet::set_input_arrays); | ||
|
||
boost::python::class_<CaffeBlob, CaffeBlobWrap>( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really? See python.org, Python wiki page, etc. Just sayin'.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, fair enough. Capital is fine and used throughout now.