take blob args as ndarrays and assign on the python side
Take blob args and give blob returns as single ndarrays instead of lists
of arrays.

Assign the net blobs and diffs as needed on the python side, which
reduces copies and simplifies the C++ side of the wrapper.

Thanks @longjon for the suggestion.
shelhamer committed May 16, 2014
1 parent 025c64e commit 5d584c2
Showing 2 changed files with 72 additions and 184 deletions.
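To make the change concrete, here is a minimal usage sketch of the new interface, assuming a hypothetical deploy net whose lone input blob is 'data' and whose lone output blob is 'prob' (the file names and blob names are placeholders, not part of this commit):

    import numpy as np
    import caffe

    net = caffe.Net('deploy.prototxt', 'weights.caffemodel')

    # Before this commit, forward() took lists of per-image arrays; now it
    # takes one batch-sized, 4-d ndarray per input blob...
    batch = np.zeros(net.blobs['data'].data.shape, dtype=np.float32)
    outs = net.forward(data=batch)
    # ...and returns {blob name: blob ndarray}, e.g. outs['prob'].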
119 changes: 5 additions & 114 deletions python/caffe/_caffe.cpp
@@ -158,18 +158,7 @@ struct CaffeNet {

   virtual ~CaffeNet() {}

-  // Check that an array is acceptable for blob assignment
-  // as described in the preface to Forward().
-  inline void check_array_against_blob(
-      PyArrayObject* arr, Blob<float>* blob, string name) {
-    check_contiguous_array(arr, name, blob->channels(), blob->height(),
-        blob->width());
-    if (PyArray_DIMS(arr)[0] != blob->num()) {
-      throw std::runtime_error(name + " has wrong batch size");
-    }
-  }
-
-  // generate Python exceptions for badly shaped or discontiguous arrays
+  // Generate Python exceptions for badly shaped or discontiguous arrays.
   inline void check_contiguous_array(PyArrayObject* arr, string name,
       int channels, int height, int width) {
     if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
@@ -192,107 +181,11 @@ struct CaffeNet {
     }
   }

-  // The actual forward function. It takes in a Python list of numpy arrays as
-  // input and a Python list of numpy arrays as output. The input and output
-  // should all have correct shapes, be single-precision, and be C-contiguous.
-  void Forward(list bottom, list top) {
-    vector<Blob<float>*>& input_blobs = net_->input_blobs();
-    CHECK_EQ(len(bottom), input_blobs.size());
-    CHECK_EQ(len(top), net_->num_outputs());
-    // First, copy the input
-    for (int i = 0; i < input_blobs.size(); ++i) {
-      object elem = bottom[i];
-      PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
-      check_array_against_blob(arr, input_blobs[i],
-          net_->blob_names()[net_->input_blob_indices()[i]]);
-      switch (Caffe::mode()) {
-      case Caffe::CPU:
-        memcpy(input_blobs[i]->mutable_cpu_data(), PyArray_DATA(arr),
-            sizeof(float) * input_blobs[i]->count());
-        break;
-      case Caffe::GPU:
-        cudaMemcpy(input_blobs[i]->mutable_gpu_data(), PyArray_DATA(arr),
-            sizeof(float) * input_blobs[i]->count(), cudaMemcpyHostToDevice);
-        break;
-      default:
-        LOG(FATAL) << "Unknown Caffe mode.";
-      }  // switch (Caffe::mode())
-    }
-    // LOG(INFO) << "Start";
-    const vector<Blob<float>*>& output_blobs = net_->ForwardPrefilled();
-    // LOG(INFO) << "End";
-    for (int i = 0; i < output_blobs.size(); ++i) {
-      object elem = top[i];
-      PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
-      check_array_against_blob(arr, output_blobs[i],
-          net_->blob_names()[net_->input_blob_indices()[i]]);
-      switch (Caffe::mode()) {
-      case Caffe::CPU:
-        memcpy(PyArray_DATA(arr), output_blobs[i]->cpu_data(),
-            sizeof(float) * output_blobs[i]->count());
-        break;
-      case Caffe::GPU:
-        cudaMemcpy(PyArray_DATA(arr), output_blobs[i]->gpu_data(),
-            sizeof(float) * output_blobs[i]->count(), cudaMemcpyDeviceToHost);
-        break;
-      default:
-        LOG(FATAL) << "Unknown Caffe mode.";
-      }  // switch (Caffe::mode())
-    }
-  }
-
-  void Backward(list top_diff, list bottom_diff) {
-    vector<Blob<float>*>& output_blobs = net_->output_blobs();
-    vector<Blob<float>*>& input_blobs = net_->input_blobs();
-    CHECK_EQ(len(bottom_diff), input_blobs.size());
-    CHECK_EQ(len(top_diff), output_blobs.size());
-    // First, copy the output diff
-    for (int i = 0; i < output_blobs.size(); ++i) {
-      object elem = top_diff[i];
-      PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
-      check_array_against_blob(arr, output_blobs[i],
-          net_->blob_names()[net_->input_blob_indices()[i]]);
-      switch (Caffe::mode()) {
-      case Caffe::CPU:
-        memcpy(output_blobs[i]->mutable_cpu_diff(), PyArray_DATA(arr),
-            sizeof(float) * output_blobs[i]->count());
-        break;
-      case Caffe::GPU:
-        cudaMemcpy(output_blobs[i]->mutable_gpu_diff(), PyArray_DATA(arr),
-            sizeof(float) * output_blobs[i]->count(), cudaMemcpyHostToDevice);
-        break;
-      default:
-        LOG(FATAL) << "Unknown Caffe mode.";
-      }  // switch (Caffe::mode())
-    }
-    // LOG(INFO) << "Start";
-    net_->Backward();
-    // LOG(INFO) << "End";
-    for (int i = 0; i < input_blobs.size(); ++i) {
-      object elem = bottom_diff[i];
-      PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
-      check_array_against_blob(arr, input_blobs[i],
-          net_->blob_names()[net_->input_blob_indices()[i]]);
-      switch (Caffe::mode()) {
-      case Caffe::CPU:
-        memcpy(PyArray_DATA(arr), input_blobs[i]->cpu_diff(),
-            sizeof(float) * input_blobs[i]->count());
-        break;
-      case Caffe::GPU:
-        cudaMemcpy(PyArray_DATA(arr), input_blobs[i]->gpu_diff(),
-            sizeof(float) * input_blobs[i]->count(), cudaMemcpyDeviceToHost);
-        break;
-      default:
-        LOG(FATAL) << "Unknown Caffe mode.";
-      }  // switch (Caffe::mode())
-    }
-  }
-
-  void ForwardPrefilled() {
+  void Forward() {
     net_->ForwardPrefilled();
   }

-  void BackwardPrefilled() {
+  void Backward() {
     net_->Backward();
   }

@@ -411,10 +304,8 @@ BOOST_PYTHON_MODULE(_caffe) {
   boost::python::class_<CaffeNet, shared_ptr<CaffeNet> >(
       "Net", boost::python::init<string, string>())
       .def(boost::python::init<string>())
-      .def("Forward", &CaffeNet::Forward)
-      .def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
-      .def("Backward", &CaffeNet::Backward)
-      .def("BackwardPrefilled", &CaffeNet::BackwardPrefilled)
+      .def("_forward", &CaffeNet::Forward)
+      .def("_backward", &CaffeNet::Backward)
       .def("set_mode_cpu", &CaffeNet::set_mode_cpu)
       .def("set_mode_gpu", &CaffeNet::set_mode_gpu)
       .def("set_phase_train", &CaffeNet::set_phase_train)
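After this change, only the shape and contiguity check remains on the C++ side; the Python wrapper is responsible for handing over arrays that satisfy it. A short sketch of preparing a conforming input, under the assumption that in_shape mirrors net.blobs['data'].data.shape for some hypothetical net:

    import numpy as np

    in_shape = (10, 3, 227, 227)  # hypothetical (num, channels, height, width)
    arr = np.random.rand(*in_shape)  # float64 by default
    # Make the array single-precision and C-contiguous as the wrapper expects.
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    assert arr.flags['C_CONTIGUOUS']
    assert arr.dtype == np.float32 and arr.ndim == 4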
137 changes: 67 additions & 70 deletions python/caffe/pycaffe.py
@@ -46,38 +46,32 @@ def _Net_forward(self, blobs=None, **kwargs):

     Take
     blobs: list of blobs to return in addition to output blobs.
-    kwargs: Keys are input blob names and values are lists of inputs.
-            Images must be (H x W x K) ndarrays.
-            If None, input is taken from data layers by ForwardPrefilled().
+    kwargs: Keys are input blob names and values are blob ndarrays.
+            For turning images into input blobs, see format_image().
+            If None, input is taken from data layers.

     Give
-    outs: {blob name: list of blobs ndarrays} dict.
+    outs: {blob name: blob ndarray} dict.
     """
     if blobs is None:
         blobs = []

-    if not kwargs:
-        # Carry out prefilled forward pass and unpack output.
-        self.ForwardPrefilled()
-        out_blobs = [self.blobs[out].data for out in self.outputs]
-    else:
-        # Create input and output blobs according to net defined shapes
-        # and make arrays single and C-contiguous as Caffe expects.
-        in_blobs = [np.ascontiguousarray(np.concatenate(kwargs[in_]),
-                                         dtype=np.float32)
-                    for in_ in self.inputs]
-        out_blobs = [np.empty(self.blobs[out].data.shape, dtype=np.float32)
-                     for out in self.outputs]
+    if kwargs:
+        if set(kwargs.keys()) != set(self.inputs):
+            raise Exception('Input blob arguments do not match net inputs.')
+        # Set input according to defined shapes and make arrays single and
+        # C-contiguous as Caffe expects.
+        for in_, blob in kwargs.iteritems():
+            if blob.shape[0] != self.blobs[in_].num:
+                raise Exception('Input is not batch sized')
+            if blob.ndim != 4:
+                raise Exception('{} blob is not 4-d'.format(in_))
+            self.blobs[in_].data[...] = blob

-        self.Forward(in_blobs, out_blobs)
+    self._forward()

-    # Unpack blobs to extract
-    outs = {}
-    out_blobs.extend([self.blobs[blob].data for blob in blobs])
-    out_blob_names = self.outputs + blobs
-    for out, out_blob in zip(out_blob_names, out_blobs):
-        outs[out] = [out_blob[ix, :, :, :]
-                     for ix in range(out_blob.shape[0])]
+    outs = {out: self.blobs[out].data for out in set(self.outputs + blobs)}
     return outs


@@ -87,37 +81,31 @@ def _Net_backward(self, diffs=None, **kwargs):

     Take
     diffs: list of diffs to return in addition to bottom diffs.
-    kwargs: Keys are output blob names and values are lists of diffs.
-            If None, top diffs are taken from loss by BackwardPrefilled().
+    kwargs: Keys are output blob names and values are diff ndarrays.
+            If None, top diffs are taken from forward loss.

     Give
-    outs: {blob name: list of diffs} dict.
+    outs: {blob name: diff ndarray} dict.
     """
     if diffs is None:
         diffs = []

-    if not kwargs:
-        # Carry out backward with forward loss diffs and unpack bottom diffs.
-        self.BackwardPrefilled()
-        out_diffs = [self.blobs[in_].diff for in_ in self.inputs]
-    else:
-        # Create top and bottom diffs according to net defined shapes
-        # and make arrays single and C-contiguous as Caffe expects.
-        top_diffs = [np.ascontiguousarray(np.concatenate(kwargs[out]),
-                                          dtype=np.float32)
-                     for out in self.outputs]
-        out_diffs = [np.empty(self.blobs[bottom].diff.shape, dtype=np.float32)
-                     for bottom in self.inputs]
+    if kwargs:
+        if set(kwargs.keys()) != set(self.outputs):
+            raise Exception('Top diff arguments do not match net outputs.')
+        # Set top diffs according to defined shapes and make arrays single and
+        # C-contiguous as Caffe expects.
+        for top, diff in kwargs.iteritems():
+            if diff.shape[0] != self.blobs[top].num:
+                raise Exception('Diff is not batch sized')
+            if diff.ndim != 4:
+                raise Exception('{} diff is not 4-d'.format(top))
+            self.blobs[top].diff[...] = diff

-        self.Backward(top_diffs, out_diffs)
+    self._backward()

-    # Unpack diffs to extract
-    outs = {}
-    out_diffs.extend([self.blobs[diff].diff for diff in diffs])
-    out_diff_names = self.inputs + diffs
-    for out, out_diff in zip(out_diff_names, out_diffs):
-        outs[out] = [out_diff[ix, :, :, :]
-                     for ix in range(out_diff.shape[0])]
+    outs = {out: self.blobs[out].diff for out in set(self.inputs + diffs)}
     return outs


@@ -127,23 +115,26 @@ def _Net_forward_all(self, blobs=None, **kwargs):

     Take
     blobs: list of blobs to extract as in forward()
-    kwargs: Keys are input blob names and values are lists of blobs.
+    kwargs: Keys are input blob names and values are blob ndarrays.
             Refer to forward().

     Give
     all_outs: {blob name: list of blobs} dict.
     """
     # Collect outputs from batches
-    all_outs = {out: [] for out in self.outputs + blobs}
+    all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
     for batch in self._batch(kwargs):
         outs = self.forward(blobs=blobs, **batch)
-        for out, out_blobs in outs.items():
-            all_outs[out].extend(out_blobs)
-    # Discard padding at the end.
+        for out, out_blob in outs.iteritems():
+            all_outs[out].extend(out_blob)
+    # Package in ndarray.
+    for out in all_outs:
+        all_outs[out] = np.asarray(all_outs[out])
+    # Discard padding.
     pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next())
     if pad:
         for out in all_outs:
-            del all_outs[out][-pad:]
+            all_outs[out] = all_outs[out][:-pad]
     return all_outs


@@ -154,17 +145,17 @@ def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs):
     Take
     blobs: list of blobs to extract as in forward()
     diffs: list of diffs to extract as in backward()
-    kwargs: Keys are input (for forward) and output (for backward) blob
-            names and values are lists of blobs. Refer to forward() and backward().
+    kwargs: Keys are input (for forward) and output (for backward) blob names
+            and values are ndarrays. Refer to forward() and backward().
             Prefilled variants are called for lack of input or output blobs.

     Give
-    all_blobs: {blob name: list of blobs} dict.
-    all_diffs: {blob name: list of diffs} dict.
+    all_blobs: {blob name: blob ndarray} dict.
+    all_diffs: {blob name: diff ndarray} dict.
     """
     # Batch blobs and diffs.
-    all_outs = {out: [] for out in self.outputs + (blobs or [])}
-    all_diffs = {diff: [] for diff in self.inputs + (diffs or [])}
+    all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
+    all_diffs = {diff: [] for diff in set(self.inputs + (diffs or []))}
     forward_batches = self._batch({in_: kwargs[in_]
                                    for in_ in self.inputs if in_ in kwargs})
     backward_batches = self._batch({out: kwargs[out]
@@ -173,17 +164,20 @@ def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs):
     for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}):
         batch_blobs = self.forward(blobs=blobs, **fb)
         batch_diffs = self.backward(diffs=diffs, **bb)
-        for out, out_blobs in batch_blobs.items():
+        for out, out_blobs in batch_blobs.iteritems():
             all_outs[out].extend(out_blobs)
-        for diff, out_diffs in batch_diffs.items():
+        for diff, out_diffs in batch_diffs.iteritems():
             all_diffs[diff].extend(out_diffs)
-    # Discard padding at the end.
+    # Package in ndarray.
+    for out, diff in zip(all_outs, all_diffs):
+        all_outs[out] = np.asarray(all_outs[out])
+        all_diffs[diff] = np.asarray(all_diffs[diff])
+    # Discard padding at the end and package in ndarray.
     pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next())
     if pad:
-        for out in all_outs:
-            del all_outs[out][-pad:]
-        for diff in all_diffs:
-            del all_diffs[diff][-pad:]
+        for out, diff in zip(all_outs, all_diffs):
+            all_outs[out] = all_outs[out][:-pad]
+            all_diffs[diff] = all_diffs[diff][:-pad]
     return all_outs, all_diffs


@@ -253,7 +247,7 @@ def _Net_format_image(self, input_, image):
     image: (H x W x K) ndarray

     Give
-    image: (K x H x W) ndarray
+    image: (1 x K x H x W) ndarray
     """
     caf_image = image.astype(np.float32)
     input_scale = self.input_scale.get(input_)
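For illustration, a hedged example of the updated helper, assuming format_image() is attached to Net like the other helpers, that the input blob is named 'data', and that the image already matches the blob's spatial size; the new singleton batch axis lets formatted images be concatenated directly into an input blob:

    import numpy as np

    image = np.random.rand(227, 227, 3)       # (H x W x K) input image
    caf_im = net.format_image('data', image)  # (1 x K x H x W)
    batch = np.concatenate([caf_im, caf_im])  # (2 x K x H x W) input blob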
@@ -318,18 +312,21 @@ def _Net_batch(self, blobs):
     num = len(blobs.itervalues().next())
     batch_size = self.blobs.itervalues().next().num
     remainder = num % batch_size
-    num_batches = (num + remainder) / batch_size
+    num_batches = num / batch_size

     # Yield full batches.
-    for b in range(num_batches-1):
+    for b in range(num_batches):
         for i in [b * batch_size]:
             yield {name: blobs[name][i:i + batch_size] for name in blobs}

     # Yield last padded batch, if any.
     if remainder > 0:
-        yield {name: blobs[name][-remainder:] +
-                     [np.zeros_like(blobs[name][0])] * remainder
-               for name in blobs}
+        padded_batch = {}
+        for name in blobs:
+            padding = np.zeros((remainder,) + blobs[name].shape[1:])
+            padded_batch[name] = np.concatenate([blobs[name][-remainder:],
+                                                 padding])
+        yield padded_batch


 # Attach methods to Net.
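Putting the pieces together: _Net_batch() slices batch-sized chunks from each input ndarray and zero-pads the last chunk, while forward_all() packages the collected outputs and strips the padded rows. A sketch, reusing the hypothetical net above with input 'data' (batch size 10) and output 'prob':

    import numpy as np

    inputs = np.zeros((25,) + net.blobs['data'].data.shape[1:],
                      dtype=np.float32)
    all_outs = net.forward_all(data=inputs)
    # Three batches run (two full, one padded from the 5 leftover inputs plus
    # 5 zero rows); the 5 padded outputs are discarded afterward.
    assert len(all_outs['prob']) == 25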
