Skip to content

Commit

Permalink
Merge pull request BVLC#15 from BVLC/master
Browse files Browse the repository at this point in the history
upstream
  • Loading branch information
bittnt authored Jul 11, 2016
2 parents 7e12530 + a1a9d30 commit bb28dad
Show file tree
Hide file tree
Showing 10 changed files with 507 additions and 27 deletions.
1 change: 1 addition & 0 deletions include/caffe/net.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Net {
public:
explicit Net(const NetParameter& param, const Net* root_net = NULL);
explicit Net(const string& param_file, Phase phase,
const int level = 0, const vector<string>* stages = NULL,
const Net* root_net = NULL);
virtual ~Net() {}

Expand Down
22 changes: 14 additions & 8 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@ if(UNIX OR APPLE)
endif()

# ---[ Install
file(GLOB files1 *.py requirements.txt)
install(FILES ${files1} DESTINATION python)

file(GLOB files2 caffe/*.py)
install(FILES ${files2} DESTINATION python/caffe)
# scripts
file(GLOB python_files *.py requirements.txt)
install(FILES ${python_files} DESTINATION python)

# module
install(DIRECTORY caffe
DESTINATION python
FILES_MATCHING
PATTERN "*.py"
PATTERN "ilsvrc_2012_mean.npy"
PATTERN "test" EXCLUDE
)

# _caffe.so
install(TARGETS pycaffe DESTINATION python/caffe)
install(DIRECTORY caffe/imagenet caffe/proto caffe/test DESTINATION python/caffe)



45 changes: 37 additions & 8 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,42 @@ void CheckContiguousArray(PyArrayObject* arr, string name,
}
}

// Net constructor for passing phase as int
shared_ptr<Net<Dtype> > Net_Init(
string param_file, int phase) {
CheckFile(param_file);
// Net constructor
shared_ptr<Net<Dtype> > Net_Init(string network_file, int phase,
const int level, const bp::object& stages,
const bp::object& weights) {
CheckFile(network_file);

// Convert stages from list to vector
vector<string> stages_vector;
if (!stages.is_none()) {
for (int i = 0; i < len(stages); i++) {
stages_vector.push_back(bp::extract<string>(stages[i]));
}
}

// Initialize net
shared_ptr<Net<Dtype> > net(new Net<Dtype>(network_file,
static_cast<Phase>(phase), level, &stages_vector));

// Load weights
if (!weights.is_none()) {
std::string weights_file_str = bp::extract<std::string>(weights);
CheckFile(weights_file_str);
net->CopyTrainedLayersFrom(weights_file_str);
}

shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
static_cast<Phase>(phase)));
return net;
}

// Net construct-and-load convenience constructor
// Legacy Net construct-and-load convenience constructor
shared_ptr<Net<Dtype> > Net_Init_Load(
string param_file, string pretrained_param_file, int phase) {
LOG(WARNING) << "DEPRECATION WARNING - deprecated use of Python interface";
LOG(WARNING) << "Use this instead (with the named \"weights\""
<< " parameter):";
LOG(WARNING) << "Net('" << param_file << "', " << phase
<< ", weights='" << pretrained_param_file << "')";
CheckFile(param_file);
CheckFile(pretrained_param_file);

Expand Down Expand Up @@ -266,11 +289,17 @@ BOOST_PYTHON_MODULE(_caffe) {

bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
bp::no_init)
.def("__init__", bp::make_constructor(&Net_Init))
// Constructor
.def("__init__", bp::make_constructor(&Net_Init,
bp::default_call_policies(), (bp::arg("network_file"), "phase",
bp::arg("level")=0, bp::arg("stages")=bp::object(),
bp::arg("weights")=bp::object())))
// Legacy constructor
.def("__init__", bp::make_constructor(&Net_Init_Load))
.def("_forward", &Net<Dtype>::ForwardFromTo)
.def("_backward", &Net<Dtype>::BackwardFromTo)
.def("reshape", &Net<Dtype>::Reshape)
.def("clear_param_diffs", &Net<Dtype>::ClearParamDiffs)
// The cast is to select a particular overload.
.def("copy_from", static_cast<void (Net<Dtype>::*)(const string)>(
&Net<Dtype>::CopyTrainedLayersFrom))
Expand Down
2 changes: 1 addition & 1 deletion python/caffe/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def detect_windows(self, images_windows):
for ix, window_in in enumerate(window_inputs):
caffe_in[ix] = self.transformer.preprocess(in_, window_in)
out = self.forward_all(**{in_: caffe_in})
predictions = out[self.outputs[0]].squeeze(axis=(2, 3))
predictions = out[self.outputs[0]]

# Package predictions with images and windows.
detections = []
Expand Down
32 changes: 27 additions & 5 deletions python/caffe/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def choose_color_by_layertype(layertype):
return color


def get_pydot_graph(caffe_net, rankdir, label_edges=True):
def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
"""Create a data structure which represents the `caffe_net`.
Parameters
Expand All @@ -137,6 +137,9 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
Direction of graph layout.
label_edges : boolean, optional
Label the edges (default is True).
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
Returns
-------
Expand All @@ -148,6 +151,19 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
pydot_nodes = {}
pydot_edges = []
for layer in caffe_net.layer:
if phase is not None:
included = False
if len(layer.include) == 0:
included = True
if len(layer.include) > 0 and len(layer.exclude) > 0:
raise ValueError('layer ' + layer.name + ' has both include '
'and exclude specified.')
for layer_phase in layer.include:
included = included or layer_phase.phase == phase
for layer_phase in layer.exclude:
included = included and not layer_phase.phase == phase
if not included:
continue
node_label = get_layer_label(layer, rankdir)
node_name = "%s_%s" % (layer.name, layer.type)
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
Expand Down Expand Up @@ -186,7 +202,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
return pydot_graph


def draw_net(caffe_net, rankdir, ext='png'):
def draw_net(caffe_net, rankdir, ext='png', phase=None):
"""Draws a caffe net and returns the image string encoded using the given
extension.
Expand All @@ -195,16 +211,19 @@ def draw_net(caffe_net, rankdir, ext='png'):
caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
ext : string, optional
The image extension (the default is 'png').
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
Returns
-------
string :
Postscript representation of the graph.
"""
return get_pydot_graph(caffe_net, rankdir).create(format=ext)
return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext)


def draw_net_to_file(caffe_net, filename, rankdir='LR'):
def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
"""Draws a caffe net, and saves it to file using the format given as the
file extension. Use '.raw' to output raw text that you can manually feed
to graphviz to draw graphs.
Expand All @@ -216,7 +235,10 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'):
The path to a file where the networks visualization will be stored.
rankdir : {'LR', 'TB', 'BT'}
Direction of graph layout.
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
"""
ext = filename[filename.rfind('.')+1:]
with open(filename, 'wb') as fid:
fid.write(draw_net(caffe_net, rankdir, ext))
fid.write(draw_net(caffe_net, rankdir, ext, phase))
Loading

0 comments on commit bb28dad

Please sign in to comment.