Skip to content

Commit

Permalink
python Net.backward() helper and Net.BackwardPrefilled()
Browse files Browse the repository at this point in the history
  • Loading branch information
shelhamer committed May 14, 2014
1 parent 9d4324e commit ac5e6fa
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
5 changes: 5 additions & 0 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ struct CaffeNet {
net_->ForwardPrefilled();
}

void BackwardPrefilled() {
net_->Backward();
}

void set_input_arrays(object data_obj, object labels_obj) {
// check that this network has an input MemoryDataLayer
shared_ptr<MemoryDataLayer<float> > md_layer =
Expand Down Expand Up @@ -411,6 +415,7 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("Forward", &CaffeNet::Forward)
.def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
.def("Backward", &CaffeNet::Backward)
.def("BackwardPrefilled", &CaffeNet::BackwardPrefilled)
.def("set_mode_cpu", &CaffeNet::set_mode_cpu)
.def("set_mode_gpu", &CaffeNet::set_mode_gpu)
.def("set_phase_train", &CaffeNet::set_phase_train)
Expand Down
37 changes: 35 additions & 2 deletions python/caffe/pycaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ def _Net_forward(self, **kwargs):
If None, input is taken from data layers by ForwardPrefilled().
Give
out: {output blob name: list of output blobs} dict.
outs: {output blob name: list of output blobs} dict.
"""
outs = {}
if not kwargs:
# Carry out prefilled forward pass and unpack output.
self.ForwardPrefilled()
Expand All @@ -70,6 +69,7 @@ def _Net_forward(self, **kwargs):
self.Forward(in_blobs, out_blobs)

# Unpack output blobs
outs = {}
for out, out_blob in zip(self.outputs, out_blobs):
outs[out] = [out_blob[ix, :, :, :].squeeze()
for ix in range(out_blob.shape[0])]
Expand All @@ -78,6 +78,39 @@ def _Net_forward(self, **kwargs):
Net.forward = _Net_forward


def _Net_backward(self, **kwargs):
"""
Backward pass: prepare diffs and run the net backward.
Take
kwargs: Keys are output blob names and values are lists of diffs.
If None, input is taken from data layers by BackwardPrefilled().
Give
bottom_diffs: {input blob name: list of diffs} dict.
"""
if not kwargs:
self.BackwardPrefilled()
bottom_diffs = [self.blobs[in_].diff for in_ in self.inputs]
else:
# Create top and bottom diffs according to net defined shapes
# and make arrays single and C-contiguous as Caffe expects.
top_diffs = [np.ascontiguousarray(np.concatenate(kwargs[out]),
dtype=np.float32) for out in self.outputs]
bottom_diffs = [np.empty(self.blobs[bottom].data.shape, dtype=np.float32)
for bottom in self.inputs]
self.Backward(top_diffs, bottom_diffs)

# Unpack bottom diffs
bottom_diffs = {}
for bottom, bottom_diff in zip(self.inputs, bottom_diffs):
bottom_diffs[bottom] = [bottom_diff[ix, :, :, :].squeeze()
for ix in range(bottom_diff.shape[0])]
return bottom_diffs

Net.backward = _Net_backward


def _Net_set_mean(self, input_, mean_f, mode='image'):
"""
Set the mean to subtract for data centering.
Expand Down

0 comments on commit ac5e6fa

Please sign in to comment.