Skip to content

Commit

Permalink
Merge pull request BVLC#311 from shelhamer/python-fixes
Browse files Browse the repository at this point in the history
Improve python wrapper
  • Loading branch information
shelhamer committed May 20, 2014
2 parents 47ad2f9 + afab167 commit be83632
Show file tree
Hide file tree
Showing 21 changed files with 1,566 additions and 1,276 deletions.
16 changes: 8 additions & 8 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ Even in CPU mode, computing predictions on an image takes only 20 ms when images

### Examples

* [Image Classification \[notebook\]][imagenet_classification]: classify images with the pretrained ImageNet model by the Python interface.
* [Detection \[notebook\]][detection]: run a pretrained model as a detector in Python.
* [Visualizing Features and Filters \[notebook\]][visualizing_filters]: extracting features and visualizing trained filters with an example image, viewed layer-by-layer.
* [LeNet / MNIST Demo](/mnist.html): end-to-end training and testing of LeNet on MNIST.
* [CIFAR-10 Demo](/cifar10.html): training and testing on the CIFAR-10 data.
* [Training ImageNet](/imagenet_training.html): end-to-end training of an ImageNet classifier.
* [Feature extraction with C++](/feature_extraction.html): feature extraction using pre-trained model
* [Running Pretrained ImageNet \[notebook\]][pretrained_imagenet]: run classification with the pretrained ImageNet model using the Python interface.
* [Running Detection \[notebook\]][imagenet_detection]: run a pretrained model as a detector.
* [Visualizing Features and Filters \[notebook\]][visualizing_filters]: trained filters and an example image, viewed layer-by-layer.

[pretrained_imagenet]: http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/imagenet_pretrained.ipynb
[imagenet_detection]: http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/selective_search_demo.ipynb
* [Training ImageNet](/imagenet_training.html): recipe for end-to-end training of an ImageNet classifier.
* [Feature extraction with C++](/feature_extraction.html): feature extraction using pre-trained model.

[imagenet_classification]: http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/imagenet_classification.ipynb
[detection]: http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/detection_search_demo.ipynb
[visualizing_filters]: http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/filter_visualization.ipynb

## Citing Caffe
Expand Down
392 changes: 176 additions & 216 deletions examples/selective_search_demo.ipynb → examples/detection.ipynb

Large diffs are not rendered by default.

99 changes: 51 additions & 48 deletions examples/filter_visualization.ipynb

Large diffs are not rendered by default.

410 changes: 410 additions & 0 deletions examples/imagenet_classification.ipynb

Large diffs are not rendered by default.

272 changes: 0 additions & 272 deletions examples/imagenet_pretrained.ipynb

This file was deleted.

2 changes: 2 additions & 0 deletions include/caffe/net.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class Net {
inline int num_outputs() { return net_output_blobs_.size(); }
inline vector<Blob<Dtype>*>& input_blobs() { return net_input_blobs_; }
inline vector<Blob<Dtype>*>& output_blobs() { return net_output_blobs_; }
inline vector<int>& input_blob_indices() { return net_input_blob_indices_; }
inline vector<int>& output_blob_indices() { return net_output_blob_indices_; }
// has_blob and blob_by_name are inspired by
// https://github.com/kencoken/caffe/commit/f36e71569455c9fbb4bf8a63c2d53224e32a4e7b
// Access intermediary computation layers, testing with centre image only
Expand Down
3 changes: 3 additions & 0 deletions python/caffe/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .pycaffe import Net, SGDSolver
from .classifier import Classifier
from .detector import Detector
import io
150 changes: 33 additions & 117 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2014 BVLC and contributors.
// pycaffe provides a wrapper of the caffe::Net class as well as some
// caffe::Caffe functions so that one could easily call it from Python.
// Note that for python, we will simply use float as the data type.
// Note that for Python, we will simply use float as the data type.

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

Expand Down Expand Up @@ -46,7 +46,7 @@ static void CheckFile(const string& filename) {
}

// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
// to Python
// to Python
class CaffeBlob {
public:
CaffeBlob(const shared_ptr<Blob<float> > &blob, const string& name)
Expand All @@ -70,9 +70,9 @@ class CaffeBlob {
};


// we need another wrapper (used as boost::python's HeldType) that receives a
// self PyObject * which we can use as ndarray.base, so that data/diff memory
// is not freed while still being used in Python
// We need another wrapper (used as boost::python's HeldType) that receives a
// self PyObject * which we can use as ndarray.base, so that data/diff memory
// is not freed while still being used in Python.
class CaffeBlobWrap : public CaffeBlob {
public:
CaffeBlobWrap(PyObject *p, const CaffeBlob &blob)
Expand Down Expand Up @@ -158,21 +158,7 @@ struct CaffeNet {

virtual ~CaffeNet() {}

// this function is mostly redundant with the one below, but should go away
// with new pycaffe
inline void check_array_against_blob(
PyArrayObject* arr, Blob<float>* blob) {
CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS);
CHECK_EQ(PyArray_NDIM(arr), 4);
CHECK_EQ(PyArray_ITEMSIZE(arr), 4);
npy_intp* dims = PyArray_DIMS(arr);
CHECK_EQ(dims[0], blob->num());
CHECK_EQ(dims[1], blob->channels());
CHECK_EQ(dims[2], blob->height());
CHECK_EQ(dims[3], blob->width());
}

// generate Python exceptions for badly shaped or discontiguous arrays
// Generate Python exceptions for badly shaped or discontiguous arrays.
inline void check_contiguous_array(PyArrayObject* arr, string name,
int channels, int height, int width) {
if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
Expand All @@ -195,101 +181,12 @@ struct CaffeNet {
}
}

// The actual forward function. It takes in a python list of numpy arrays as
// input and a python list of numpy arrays as output. The input and output
// should all have correct shapes, be single-precision, and be
// C-contiguous.
void Forward(list bottom, list top) {
vector<Blob<float>*>& input_blobs = net_->input_blobs();
CHECK_EQ(len(bottom), input_blobs.size());
CHECK_EQ(len(top), net_->num_outputs());
// First, copy the input
for (int i = 0; i < input_blobs.size(); ++i) {
object elem = bottom[i];
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
check_array_against_blob(arr, input_blobs[i]);
switch (Caffe::mode()) {
case Caffe::CPU:
memcpy(input_blobs[i]->mutable_cpu_data(), PyArray_DATA(arr),
sizeof(float) * input_blobs[i]->count());
break;
case Caffe::GPU:
cudaMemcpy(input_blobs[i]->mutable_gpu_data(), PyArray_DATA(arr),
sizeof(float) * input_blobs[i]->count(), cudaMemcpyHostToDevice);
break;
default:
LOG(FATAL) << "Unknown Caffe mode.";
} // switch (Caffe::mode())
}
// LOG(INFO) << "Start";
const vector<Blob<float>*>& output_blobs = net_->ForwardPrefilled();
// LOG(INFO) << "End";
for (int i = 0; i < output_blobs.size(); ++i) {
object elem = top[i];
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
check_array_against_blob(arr, output_blobs[i]);
switch (Caffe::mode()) {
case Caffe::CPU:
memcpy(PyArray_DATA(arr), output_blobs[i]->cpu_data(),
sizeof(float) * output_blobs[i]->count());
break;
case Caffe::GPU:
cudaMemcpy(PyArray_DATA(arr), output_blobs[i]->gpu_data(),
sizeof(float) * output_blobs[i]->count(), cudaMemcpyDeviceToHost);
break;
default:
LOG(FATAL) << "Unknown Caffe mode.";
} // switch (Caffe::mode())
}
void Forward() {
net_->ForwardPrefilled();
}

void Backward(list top_diff, list bottom_diff) {
vector<Blob<float>*>& output_blobs = net_->output_blobs();
vector<Blob<float>*>& input_blobs = net_->input_blobs();
CHECK_EQ(len(bottom_diff), input_blobs.size());
CHECK_EQ(len(top_diff), output_blobs.size());
// First, copy the output diff
for (int i = 0; i < output_blobs.size(); ++i) {
object elem = top_diff[i];
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
check_array_against_blob(arr, output_blobs[i]);
switch (Caffe::mode()) {
case Caffe::CPU:
memcpy(output_blobs[i]->mutable_cpu_diff(), PyArray_DATA(arr),
sizeof(float) * output_blobs[i]->count());
break;
case Caffe::GPU:
cudaMemcpy(output_blobs[i]->mutable_gpu_diff(), PyArray_DATA(arr),
sizeof(float) * output_blobs[i]->count(), cudaMemcpyHostToDevice);
break;
default:
LOG(FATAL) << "Unknown Caffe mode.";
} // switch (Caffe::mode())
}
// LOG(INFO) << "Start";
void Backward() {
net_->Backward();
// LOG(INFO) << "End";
for (int i = 0; i < input_blobs.size(); ++i) {
object elem = bottom_diff[i];
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
check_array_against_blob(arr, input_blobs[i]);
switch (Caffe::mode()) {
case Caffe::CPU:
memcpy(PyArray_DATA(arr), input_blobs[i]->cpu_diff(),
sizeof(float) * input_blobs[i]->count());
break;
case Caffe::GPU:
cudaMemcpy(PyArray_DATA(arr), input_blobs[i]->gpu_diff(),
sizeof(float) * input_blobs[i]->count(), cudaMemcpyDeviceToHost);
break;
default:
LOG(FATAL) << "Unknown Caffe mode.";
} // switch (Caffe::mode())
}
}

void ForwardPrefilled() {
net_->ForwardPrefilled();
}

void set_input_arrays(object data_obj, object labels_obj) {
Expand Down Expand Up @@ -350,6 +247,24 @@ struct CaffeNet {
return result;
}

// Build a Python list with the names of the net's input blobs, looked up
// through the net's input blob indices in order.
list inputs() {
  list names;
  const int num_inputs = net_->input_blob_indices().size();
  for (int idx = 0; idx < num_inputs; ++idx) {
    const int blob_id = net_->input_blob_indices()[idx];
    names.append(net_->blob_names()[blob_id]);
  }
  return names;
}

// Build a Python list with the names of the net's output blobs, looked up
// through the net's output blob indices in order.
list outputs() {
  list names;
  const int num_outputs = net_->output_blob_indices().size();
  for (int idx = 0; idx < num_outputs; ++idx) {
    const int blob_id = net_->output_blob_indices()[idx];
    names.append(net_->blob_names()[blob_id]);
  }
  return names;
}

// The pointer to the internal caffe::Net instance.
shared_ptr<Net<float> > net_;
// if taking input from an ndarray, we need to hold references
Expand Down Expand Up @@ -382,23 +297,24 @@ class CaffeSGDSolver {
};


// The boost python module definition.
// The boost_python module definition.
BOOST_PYTHON_MODULE(_caffe) {
// below, we prepend an underscore to methods that will be replaced
// in Python
// in Python
boost::python::class_<CaffeNet, shared_ptr<CaffeNet> >(
"Net", boost::python::init<string, string>())
.def(boost::python::init<string>())
.def("Forward", &CaffeNet::Forward)
.def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
.def("Backward", &CaffeNet::Backward)
.def("_forward", &CaffeNet::Forward)
.def("_backward", &CaffeNet::Backward)
.def("set_mode_cpu", &CaffeNet::set_mode_cpu)
.def("set_mode_gpu", &CaffeNet::set_mode_gpu)
.def("set_phase_train", &CaffeNet::set_phase_train)
.def("set_phase_test", &CaffeNet::set_phase_test)
.def("set_device", &CaffeNet::set_device)
.add_property("_blobs", &CaffeNet::blobs)
.add_property("layers", &CaffeNet::layers)
.add_property("inputs", &CaffeNet::inputs)
.add_property("outputs", &CaffeNet::outputs)
.def("_set_input_arrays", &CaffeNet::set_input_arrays);

boost::python::class_<CaffeBlob, CaffeBlobWrap>(
Expand Down
86 changes: 86 additions & 0 deletions python/caffe/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!/usr/bin/env python
"""
Classifier is an image classifier specialization of Net.
"""

import numpy as np

import caffe


class Classifier(caffe.Net):
    """
    Classifier extends Net for image class prediction
    by scaling, center cropping, or oversampling.
    """
    def __init__(self, model_file, pretrained_file, image_dims=None,
                 gpu=False, mean_file=None, input_scale=None,
                 channel_swap=None):
        """
        Load the net in test phase and configure input preprocessing.

        Take
        model_file, pretrained_file: passed straight to caffe.Net.__init__.
        image_dims: dimensions to scale input for cropping/sampling.
            Default is to scale to net input size for whole-image crop.
        gpu, mean_file, input_scale, channel_swap: convenience params for
            setting mode, mean, input scale, and channel order.
            Each preprocessing option is applied to the first net input
            only, and only when a truthy value is given.
        """
        caffe.Net.__init__(self, model_file, pretrained_file)
        self.set_phase_test()

        if gpu:
            self.set_mode_gpu()
        else:
            self.set_mode_cpu()

        if mean_file:
            self.set_mean(self.inputs[0], mean_file)
        if input_scale:
            self.set_input_scale(self.inputs[0], input_scale)
        if channel_swap:
            self.set_channel_swap(self.inputs[0], channel_swap)

        # The crop size is the trailing two dimensions of the first input
        # blob's data array.
        self.crop_dims = np.array(self.blobs[self.inputs[0]].data.shape[2:])
        if not image_dims:
            image_dims = self.crop_dims
        self.image_dims = image_dims

    def predict(self, inputs, oversample=True):
        """
        Predict classification probabilities of inputs.

        Take
        inputs: iterable of (H x W x K) input ndarrays.
        oversample: average predictions across center, corners, and mirrors
            when True (default). Center-only prediction when False.

        Give
        predictions: (N x C) ndarray of class probabilities
            for N images and C classes.
        """
        # Scale to standardize input dimensions.
        inputs = np.asarray([caffe.io.resize_image(im, self.image_dims)
                             for im in inputs])

        if oversample:
            # Generate center, corner, and mirrored crops.
            inputs = caffe.io.oversample(inputs, self.crop_dims)
        else:
            # Take center crop.
            center = np.array(self.image_dims) / 2.0
            crop = np.tile(center, (1, 2))[0] + np.concatenate([
                -self.crop_dims / 2.0,
                self.crop_dims / 2.0
            ])
            # The bounds are computed in floating point; slice indices must
            # be integers, so truncate them before indexing.
            crop = crop.astype(int)
            inputs = inputs[:, crop[0]:crop[2], crop[1]:crop[3], :]

        # Classify
        caffe_in = np.asarray([self.preprocess(self.inputs[0], in_)
                               for in_ in inputs])
        out = self.forward_all(**{self.inputs[0]: caffe_in})
        predictions = out[self.outputs[0]].squeeze(axis=(2, 3))

        # For oversampling, average predictions across crops.
        # NOTE(review): assumes caffe.io.oversample yields exactly 10 crops
        # per image -- confirm against caffe.io.
        if oversample:
            # Floor division keeps the reshape dimension an integer.
            predictions = predictions.reshape(
                (len(predictions) // 10, 10, -1))
            predictions = predictions.mean(1)

        return predictions
Empty file removed python/caffe/detection/__init__.py
Empty file.
Loading

0 comments on commit be83632

Please sign in to comment.