From ac87850887f064752c2ad815367484c07eaf5449 Mon Sep 17 00:00:00 2001 From: Marco Castelluccio Date: Wed, 26 Aug 2015 19:03:59 -0700 Subject: [PATCH 1/8] No need to squeeze the output of the network --- python/caffe/detector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/caffe/detector.py b/python/caffe/detector.py index 75cd3b1202f..ef1f91730bf 100644 --- a/python/caffe/detector.py +++ b/python/caffe/detector.py @@ -83,7 +83,7 @@ def detect_windows(self, images_windows): for ix, window_in in enumerate(window_inputs): caffe_in[ix] = self.transformer.preprocess(in_, window_in) out = self.forward_all(**{in_: caffe_in}) - predictions = out[self.outputs[0]].squeeze(axis=(2, 3)) + predictions = out[self.outputs[0]] # Package predictions with images and windows. detections = [] From d167e61a23a54de529d51731fbe543ff4cec0d3c Mon Sep 17 00:00:00 2001 From: Luke Yeager Date: Wed, 1 Jun 2016 09:50:57 -0700 Subject: [PATCH 2/8] Add level and stages to Net constructor This internal functionality will be exposed through the various interfaces in subsequent commits Also adds C++ tests for all-in-one nets --- include/caffe/net.hpp | 1 + src/caffe/net.cpp | 11 +++- src/caffe/test/test_net.cpp | 128 ++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 1 deletion(-) diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index 0addb3c2a6d..493bdf294e2 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -25,6 +25,7 @@ class Net { public: explicit Net(const NetParameter& param, const Net* root_net = NULL); explicit Net(const string& param_file, Phase phase, + const int level = 0, const vector* stages = NULL, const Net* root_net = NULL); virtual ~Net() {} diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index f0bf594936c..644cb7e97ee 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -28,11 +28,20 @@ Net::Net(const NetParameter& param, const Net* root_net) } template -Net::Net(const string& param_file, Phase phase, const Net* root_net) +Net::Net(const string& param_file, Phase phase, + const int level, const vector* stages, + const Net* root_net) : root_net_(root_net) { NetParameter param; ReadNetParamsFromTextFileOrDie(param_file, ¶m); + // Set phase, stages and level param.mutable_state()->set_phase(phase); + if (stages != NULL) { + for (int i = 0; i < stages->size(); i++) { + param.mutable_state()->add_stage((*stages)[i]); + } + } + param.mutable_state()->set_level(level); Init(param); } diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp index 92fd317fee8..24b957f2acc 100644 --- a/src/caffe/test/test_net.cpp +++ b/src/caffe/test/test_net.cpp @@ -9,6 +9,7 @@ #include "caffe/common.hpp" #include "caffe/filler.hpp" #include "caffe/net.hpp" +#include "caffe/util/io.hpp" #include "caffe/util/math_functions.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -29,6 +30,17 @@ class NetTest : public MultiDeviceTest { net_.reset(new Net(param)); } + virtual void InitNetFromProtoFileWithState(const string& proto, + Phase phase = caffe::TRAIN, const int level = 0, + const vector* stages = NULL) { + NetParameter param; + CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m)); + string param_file; + MakeTempFilename(¶m_file); + WriteProtoToTextFile(param, param_file); + net_.reset(new Net(param_file, phase, level, stages)); + } + virtual void CopyNetBlobs(const bool copy_diff, vector > >* blobs_copy) { CHECK(net_); @@ -771,6 +783,62 @@ class NetTest : public MultiDeviceTest { InitNetFromProtoString(proto); } + virtual void InitAllInOneNet(Phase phase = caffe::TRAIN, + const int level = 0, const vector* stages = NULL) { + string proto = + "name: 'All-in-one Network'" + "layer { " + " name: 'train-data' " + " type: 'DummyData' " + " top: 'data' " + " top: 'label' " + " dummy_data_param { " + " shape { dim: 1 dim: 10 } " + " shape { dim: 1 dim: 1 } " + " } " + " include { phase: TRAIN stage: 'train' } " + "} " + "layer { " + " name: 'val-data' " + " type: 'DummyData' " + " top: 'data' " + " top: 'label' " + " dummy_data_param { " + " shape { dim: 1 dim: 10 } " + " shape { dim: 1 dim: 1 } " + " } " + " include { phase: TEST stage: 'val' } " + "} " + "layer { " + " name: 'deploy-data' " + " type: 'Input' " + " top: 'data' " + " input_param { " + " shape { dim: 1 dim: 10 } " + " } " + " include { phase: TEST stage: 'deploy' } " + "} " + "layer { " + " name: 'ip' " + " type: 'InnerProduct' " + " bottom: 'data' " + " top: 'ip' " + " inner_product_param { " + " num_output: 2 " + " } " + "} " + "layer { " + " name: 'loss' " + " type: 'SoftmaxWithLoss' " + " bottom: 'ip' " + " bottom: 'label' " + " top: 'loss' " + " include { phase: TRAIN stage: 'train' } " + " include { phase: TEST stage: 'val' } " + "} "; + InitNetFromProtoFileWithState(proto, phase, level, stages); + } + int seed_; shared_ptr > net_; }; @@ -2473,4 +2541,64 @@ TYPED_TEST(NetTest, TestForcePropagateDown) { } } +TYPED_TEST(NetTest, TestAllInOneNetTrain) { + vector stages; + stages.push_back("train"); + this->InitAllInOneNet(caffe::TRAIN, 0, &stages); + bool found_data = false; + bool found_loss = false; + for (int i = 0; i < this->net_->layers().size(); ++i) { + const string& layer_name = this->net_->layer_names()[i]; + if (layer_name == "train-data") { + found_data = true; + } else if (layer_name == "loss") { + found_loss = true; + } else { + ASSERT_NE(layer_name, "val-data"); + ASSERT_NE(layer_name, "deploy-data"); + } + } + ASSERT_TRUE(found_data); + ASSERT_TRUE(found_loss); +} + +TYPED_TEST(NetTest, TestAllInOneNetVal) { + vector stages; + stages.push_back("val"); + this->InitAllInOneNet(caffe::TEST, 0, &stages); + bool found_data = false; + bool found_loss = false; + for (int i = 0; i < this->net_->layers().size(); ++i) { + const string& layer_name = this->net_->layer_names()[i]; + if (layer_name == "val-data") { + found_data = true; + } else if (layer_name == "loss") { + found_loss = true; + } else { + ASSERT_NE(layer_name, "train-data"); + ASSERT_NE(layer_name, "deploy-data"); + } + } + ASSERT_TRUE(found_data); + ASSERT_TRUE(found_loss); +} + +TYPED_TEST(NetTest, TestAllInOneNetDeploy) { + vector stages; + stages.push_back("deploy"); + this->InitAllInOneNet(caffe::TEST, 0, &stages); + bool found_data = false; + for (int i = 0; i < this->net_->layers().size(); ++i) { + const string& layer_name = this->net_->layer_names()[i]; + if (layer_name == "deploy-data") { + found_data = true; + } else { + ASSERT_NE(layer_name, "train-data"); + ASSERT_NE(layer_name, "val-data"); + ASSERT_NE(layer_name, "loss"); + } + } + ASSERT_TRUE(found_data); +} + } // namespace caffe From 66e84d785a72d66511bffe30c0f016af9103deb8 Mon Sep 17 00:00:00 2001 From: Luke Yeager Date: Wed, 1 Jun 2016 09:56:51 -0700 Subject: [PATCH 3/8] Add phase, level and stages to tools/caffe Adds command-line flags for phase, level and stage train -- override level and stages for test_state from solver test -- set level and stages time -- set phase, level and stages --- tools/caffe.cpp | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 5bb60eb161d..9bf4214ad93 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -34,6 +34,13 @@ DEFINE_string(solver, "", "The solver definition protocol buffer text file."); DEFINE_string(model, "", "The model definition protocol buffer text file."); +DEFINE_string(phase, "", + "Optional; network phase (TRAIN or TEST). Only used for 'time'."); +DEFINE_int32(level, 0, + "Optional; network level."); +DEFINE_string(stage, "", + "Optional; network stages (not to be confused with phase), " + "separated by ','."); DEFINE_string(snapshot, "", "Optional; the snapshot solver state to resume training."); DEFINE_string(weights, "", @@ -101,6 +108,25 @@ static void get_gpus(vector* gpus) { } } +// Parse phase from flags +caffe::Phase get_phase_from_flags(caffe::Phase default_value) { + if (FLAGS_phase == "") + return default_value; + if (FLAGS_phase == "TRAIN") + return caffe::TRAIN; + if (FLAGS_phase == "TEST") + return caffe::TEST; + LOG(FATAL) << "phase must be \"TRAIN\" or \"TEST\""; + return caffe::TRAIN; // Avoid warning +} + +// Parse stages from flags +vector get_stages_from_flags() { + vector stages; + boost::split(stages, FLAGS_stage, boost::is_any_of(",")); + return stages; +} + // caffe commands to call by // caffe // @@ -156,10 +182,16 @@ int train() { CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size()) << "Give a snapshot to resume training or weights to finetune " "but not both."; + vector stages = get_stages_from_flags(); caffe::SolverParameter solver_param; caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param); + solver_param.mutable_train_state()->set_level(FLAGS_level); + for (int i = 0; i < stages.size(); i++) { + solver_param.mutable_train_state()->add_stage(stages[i]); + } + // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. if (FLAGS_gpu.size() == 0 @@ -229,6 +261,7 @@ RegisterBrewFunction(train); int test() { CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to score."; CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score."; + vector stages = get_stages_from_flags(); // Set device id and mode vector gpus; @@ -247,7 +280,7 @@ int test() { Caffe::set_mode(Caffe::CPU); } // Instantiate the caffe net. - Net caffe_net(FLAGS_model, caffe::TEST); + Net caffe_net(FLAGS_model, caffe::TEST, FLAGS_level, &stages); caffe_net.CopyTrainedLayersFrom(FLAGS_weights); LOG(INFO) << "Running for " << FLAGS_iterations << " iterations."; @@ -300,6 +333,8 @@ RegisterBrewFunction(test); // Time: benchmark the execution time of a model. int time() { CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to time."; + caffe::Phase phase = get_phase_from_flags(caffe::TRAIN); + vector stages = get_stages_from_flags(); // Set device id and mode vector gpus; @@ -313,7 +348,7 @@ int time() { Caffe::set_mode(Caffe::CPU); } // Instantiate the caffe net. - Net caffe_net(FLAGS_model, caffe::TRAIN); + Net caffe_net(FLAGS_model, phase, FLAGS_level, &stages); // Do a clean forward and backward pass, so that memory allocation are done // and future iterations will be more stable. From 19adc7a79e3acacc777076143357cc0569781cd3 Mon Sep 17 00:00:00 2001 From: Luke Yeager Date: Wed, 1 Jun 2016 10:02:41 -0700 Subject: [PATCH 4/8] Add level and stages to pycaffe Uses Boost.Python's pattern matching to differentiate between constructors Also adds Python tests for all-in-one nets --- python/caffe/_caffe.cpp | 44 +++++-- python/caffe/test/test_net.py | 228 +++++++++++++++++++++++++++++++++- 2 files changed, 263 insertions(+), 9 deletions(-) diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 48a0c8f2e95..e2726286dfb 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -86,19 +86,42 @@ void CheckContiguousArray(PyArrayObject* arr, string name, } } -// Net constructor for passing phase as int -shared_ptr > Net_Init( - string param_file, int phase) { - CheckFile(param_file); +// Net constructor +shared_ptr > Net_Init(string network_file, int phase, + const int level, const bp::object& stages, + const bp::object& weights) { + CheckFile(network_file); + + // Convert stages from list to vector + vector stages_vector; + if (!stages.is_none()) { + for (int i = 0; i < len(stages); i++) { + stages_vector.push_back(bp::extract(stages[i])); + } + } + + // Initialize net + shared_ptr > net(new Net(network_file, + static_cast(phase), level, &stages_vector)); + + // Load weights + if (!weights.is_none()) { + std::string weights_file_str = bp::extract(weights); + CheckFile(weights_file_str); + net->CopyTrainedLayersFrom(weights_file_str); + } - shared_ptr > net(new Net(param_file, - static_cast(phase))); return net; } -// Net construct-and-load convenience constructor +// Legacy Net construct-and-load convenience constructor shared_ptr > Net_Init_Load( string param_file, string pretrained_param_file, int phase) { + LOG(WARNING) << "DEPRECATION WARNING - deprecated use of Python interface"; + LOG(WARNING) << "Use this instead (with the named \"weights\"" + << " parameter):"; + LOG(WARNING) << "Net('" << param_file << "', " << phase + << ", weights='" << pretrained_param_file << "')"; CheckFile(param_file); CheckFile(pretrained_param_file); @@ -245,7 +268,12 @@ BOOST_PYTHON_MODULE(_caffe) { bp::class_, shared_ptr >, boost::noncopyable >("Net", bp::no_init) - .def("__init__", bp::make_constructor(&Net_Init)) + // Constructor + .def("__init__", bp::make_constructor(&Net_Init, + bp::default_call_policies(), (bp::arg("network_file"), "phase", + bp::arg("level")=0, bp::arg("stages")=bp::object(), + bp::arg("weights")=bp::object()))) + // Legacy constructor .def("__init__", bp::make_constructor(&Net_Init_Load)) .def("_forward", &Net::ForwardFromTo) .def("_backward", &Net::BackwardFromTo) diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py index 4cacfcd05bb..300aabdeea5 100644 --- a/python/caffe/test/test_net.py +++ b/python/caffe/test/test_net.py @@ -72,7 +72,11 @@ def test_save_and_read(self): f.close() self.net.save(f.name) net_file = simple_net_file(self.num_output) - net2 = caffe.Net(net_file, f.name, caffe.TRAIN) + # Test legacy constructor + # should print deprecation warning + caffe.Net(net_file, f.name, caffe.TRAIN) + # Test named constructor + net2 = caffe.Net(net_file, caffe.TRAIN, weights=f.name) os.remove(net_file) os.remove(f.name) for name in self.net.params: @@ -93,3 +97,225 @@ def test_save_hdf5(self): for i in range(len(self.net.params[name])): self.assertEqual(abs(self.net.params[name][i].data - net2.params[name][i].data).sum(), 0) + +class TestLevels(unittest.TestCase): + + TEST_NET = """ +layer { + name: "data" + type: "DummyData" + top: "data" + dummy_data_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } } +} +layer { + name: "NoLevel" + type: "InnerProduct" + bottom: "data" + top: "NoLevel" + inner_product_param { num_output: 1 } +} +layer { + name: "Level0Only" + type: "InnerProduct" + bottom: "data" + top: "Level0Only" + include { min_level: 0 max_level: 0 } + inner_product_param { num_output: 1 } +} +layer { + name: "Level1Only" + type: "InnerProduct" + bottom: "data" + top: "Level1Only" + include { min_level: 1 max_level: 1 } + inner_product_param { num_output: 1 } +} +layer { + name: "Level>=0" + type: "InnerProduct" + bottom: "data" + top: "Level>=0" + include { min_level: 0 } + inner_product_param { num_output: 1 } +} +layer { + name: "Level>=1" + type: "InnerProduct" + bottom: "data" + top: "Level>=1" + include { min_level: 1 } + inner_product_param { num_output: 1 } +} +""" + + def setUp(self): + self.f = tempfile.NamedTemporaryFile(mode='w+') + self.f.write(self.TEST_NET) + self.f.flush() + + def tearDown(self): + self.f.close() + + def check_net(self, net, blobs): + net_blobs = [b for b in net.blobs.keys() if 'data' not in b] + self.assertEqual(net_blobs, blobs) + + def test_0(self): + net = caffe.Net(self.f.name, caffe.TEST) + self.check_net(net, ['NoLevel', 'Level0Only', 'Level>=0']) + + def test_1(self): + net = caffe.Net(self.f.name, caffe.TEST, level=1) + self.check_net(net, ['NoLevel', 'Level1Only', 'Level>=0', 'Level>=1']) + + +class TestStages(unittest.TestCase): + + TEST_NET = """ +layer { + name: "data" + type: "DummyData" + top: "data" + dummy_data_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } } +} +layer { + name: "A" + type: "InnerProduct" + bottom: "data" + top: "A" + include { stage: "A" } + inner_product_param { num_output: 1 } +} +layer { + name: "B" + type: "InnerProduct" + bottom: "data" + top: "B" + include { stage: "B" } + inner_product_param { num_output: 1 } +} +layer { + name: "AorB" + type: "InnerProduct" + bottom: "data" + top: "AorB" + include { stage: "A" } + include { stage: "B" } + inner_product_param { num_output: 1 } +} +layer { + name: "AandB" + type: "InnerProduct" + bottom: "data" + top: "AandB" + include { stage: "A" stage: "B" } + inner_product_param { num_output: 1 } +} +""" + + def setUp(self): + self.f = tempfile.NamedTemporaryFile(mode='w+') + self.f.write(self.TEST_NET) + self.f.flush() + + def tearDown(self): + self.f.close() + + def check_net(self, net, blobs): + net_blobs = [b for b in net.blobs.keys() if 'data' not in b] + self.assertEqual(net_blobs, blobs) + + def test_A(self): + net = caffe.Net(self.f.name, caffe.TEST, stages=['A']) + self.check_net(net, ['A', 'AorB']) + + def test_B(self): + net = caffe.Net(self.f.name, caffe.TEST, stages=['B']) + self.check_net(net, ['B', 'AorB']) + + def test_AandB(self): + net = caffe.Net(self.f.name, caffe.TEST, stages=['A', 'B']) + self.check_net(net, ['A', 'B', 'AorB', 'AandB']) + + +class TestAllInOne(unittest.TestCase): + + TEST_NET = """ +layer { + name: "train_data" + type: "DummyData" + top: "data" + top: "label" + dummy_data_param { + shape { dim: 1 dim: 1 dim: 10 dim: 10 } + shape { dim: 1 dim: 1 dim: 1 dim: 1 } + } + include { phase: TRAIN stage: "train" } +} +layer { + name: "val_data" + type: "DummyData" + top: "data" + top: "label" + dummy_data_param { + shape { dim: 1 dim: 1 dim: 10 dim: 10 } + shape { dim: 1 dim: 1 dim: 1 dim: 1 } + } + include { phase: TEST stage: "val" } +} +layer { + name: "deploy_data" + type: "Input" + top: "data" + input_param { shape { dim: 1 dim: 1 dim: 10 dim: 10 } } + include { phase: TEST stage: "deploy" } +} +layer { + name: "ip" + type: "InnerProduct" + bottom: "data" + top: "ip" + inner_product_param { num_output: 2 } +} +layer { + name: "loss" + type: "SoftmaxWithLoss" + bottom: "ip" + bottom: "label" + top: "loss" + include: { phase: TRAIN stage: "train" } + include: { phase: TEST stage: "val" } +} +layer { + name: "pred" + type: "Softmax" + bottom: "ip" + top: "pred" + include: { phase: TEST stage: "deploy" } +} +""" + + def setUp(self): + self.f = tempfile.NamedTemporaryFile(mode='w+') + self.f.write(self.TEST_NET) + self.f.flush() + + def tearDown(self): + self.f.close() + + def check_net(self, net, outputs): + self.assertEqual(list(net.blobs['data'].shape), [1,1,10,10]) + self.assertEqual(net.outputs, outputs) + + def test_train(self): + net = caffe.Net(self.f.name, caffe.TRAIN, stages=['train']) + self.check_net(net, ['loss']) + + def test_val(self): + net = caffe.Net(self.f.name, caffe.TEST, stages=['val']) + self.check_net(net, ['loss']) + + def test_deploy(self): + net = caffe.Net(self.f.name, caffe.TEST, stages=['deploy']) + self.check_net(net, ['pred']) + From 118c97ff5890e92b9aa603d925d947d45086b330 Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Tue, 21 Jun 2016 17:37:55 -0700 Subject: [PATCH 5/8] add clear_param_diffs to the python net interface --- python/caffe/_caffe.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 334088e8a57..a7fb886aa06 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -271,6 +271,7 @@ BOOST_PYTHON_MODULE(_caffe) { .def("_forward", &Net::ForwardFromTo) .def("_backward", &Net::BackwardFromTo) .def("reshape", &Net::Reshape) + .def("clear_param_diffs", &Net::ClearParamDiffs) // The cast is to select a particular overload. .def("copy_from", static_cast::*)(const string)>( &Net::CopyTrainedLayersFrom)) From 892c78dd7833f1818a76d4025076b34946200fa0 Mon Sep 17 00:00:00 2001 From: Valentin Tolmer Date: Tue, 21 Jun 2016 17:42:31 -0700 Subject: [PATCH 6/8] add unit test for clear_param_diffs --- python/caffe/test/test_net.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py index 4cacfcd05bb..7fb9f475d43 100644 --- a/python/caffe/test/test_net.py +++ b/python/caffe/test/test_net.py @@ -63,6 +63,17 @@ def test_forward_backward(self): self.net.forward() self.net.backward() + def test_clear_param_diffs(self): + # Run a forward/backward step to have non-zero diffs + self.net.forward() + self.net.backward() + diff = self.net.params["conv"][0].diff + # Check that we have non-zero diffs + self.assertTrue(diff.max() > 0) + self.net.clear_param_diffs() + # Check that the diffs are now 0 + self.assertTrue((diff == 0).all()) + def test_inputs_outputs(self): self.assertEqual(self.net.inputs, []) self.assertEqual(self.net.outputs, ['loss']) From f0b1a9e770594f93fecda9e876faaafaede2b496 Mon Sep 17 00:00:00 2001 From: Carl Doersch Date: Sun, 3 Jul 2016 12:32:19 -0700 Subject: [PATCH 7/8] Add phase support for draw net --- python/caffe/draw.py | 32 +++++++++++++++++++++++++++----- python/draw_net.py | 15 ++++++++++++++- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/python/caffe/draw.py b/python/caffe/draw.py index 61205ca9f37..9eecf6d7b46 100644 --- a/python/caffe/draw.py +++ b/python/caffe/draw.py @@ -127,7 +127,7 @@ def choose_color_by_layertype(layertype): return color -def get_pydot_graph(caffe_net, rankdir, label_edges=True): +def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None): """Create a data structure which represents the `caffe_net`. Parameters @@ -137,6 +137,9 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True): Direction of graph layout. label_edges : boolean, optional Label the edges (default is True). + phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional + Include layers from this network phase. If None, include all layers. + (the default is None) Returns ------- @@ -148,6 +151,19 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True): pydot_nodes = {} pydot_edges = [] for layer in caffe_net.layer: + if phase is not None: + included = False + if len(layer.include) == 0: + included = True + if len(layer.include) > 0 and len(layer.exclude) > 0: + raise ValueError('layer ' + layer.name + ' has both include ' + 'and exclude specified.') + for layer_phase in layer.include: + included = included or layer_phase.phase == phase + for layer_phase in layer.exclude: + included = included and not layer_phase.phase == phase + if not included: + continue node_label = get_layer_label(layer, rankdir) node_name = "%s_%s" % (layer.name, layer.type) if (len(layer.bottom) == 1 and len(layer.top) == 1 and @@ -186,7 +202,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True): return pydot_graph -def draw_net(caffe_net, rankdir, ext='png'): +def draw_net(caffe_net, rankdir, ext='png', phase=None): """Draws a caffe net and returns the image string encoded using the given extension. @@ -195,16 +211,19 @@ def draw_net(caffe_net, rankdir, ext='png'): caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer. ext : string, optional The image extension (the default is 'png'). + phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional + Include layers from this network phase. If None, include all layers. + (the default is None) Returns ------- string : Postscript representation of the graph. """ - return get_pydot_graph(caffe_net, rankdir).create(format=ext) + return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext) -def draw_net_to_file(caffe_net, filename, rankdir='LR'): +def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None): """Draws a caffe net, and saves it to file using the format given as the file extension. Use '.raw' to output raw text that you can manually feed to graphviz to draw graphs. @@ -216,7 +235,10 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'): The path to a file where the networks visualization will be stored. rankdir : {'LR', 'TB', 'BT'} Direction of graph layout. + phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional + Include layers from this network phase. If None, include all layers. + (the default is None) """ ext = filename[filename.rfind('.')+1:] with open(filename, 'wb') as fid: - fid.write(draw_net(caffe_net, rankdir, ext)) + fid.write(draw_net(caffe_net, rankdir, ext, phase)) diff --git a/python/draw_net.py b/python/draw_net.py index ec76a744da3..dfe70d26a71 100755 --- a/python/draw_net.py +++ b/python/draw_net.py @@ -28,6 +28,11 @@ def parse_args(): 'http://www.graphviz.org/doc/info/' 'attrs.html#k:rankdir'), default='LR') + parser.add_argument('--phase', + help=('Which network phase to draw: can be TRAIN, ' + 'TEST, or ALL. If ALL, then all layers are drawn ' + 'regardless of phase.'), + default="ALL") args = parser.parse_args() return args @@ -38,7 +43,15 @@ def main(): net = caffe_pb2.NetParameter() text_format.Merge(open(args.input_net_proto_file).read(), net) print('Drawing net to %s' % args.output_image_file) - caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir) + phase=None; + if args.phase == "TRAIN": + phase = caffe.TRAIN + elif args.phase == "TEST": + phase = caffe.TEST + elif args.phase != "ALL": + raise ValueError("Unknown phase: " + args.phase) + caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir, + phase) if __name__ == '__main__': From f9fd20ea3893c515b19cae6fa3693b1649fb9487 Mon Sep 17 00:00:00 2001 From: Luke Yeager Date: Fri, 8 Jul 2016 12:05:17 -0700 Subject: [PATCH 8/8] Fix Python installation with CMake install target --- python/CMakeLists.txt | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index a22641401f0..bf492a24b1c 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -22,13 +22,19 @@ if(UNIX OR APPLE) endif() # ---[ Install -file(GLOB files1 *.py requirements.txt) -install(FILES ${files1} DESTINATION python) - -file(GLOB files2 caffe/*.py) -install(FILES ${files2} DESTINATION python/caffe) +# scripts +file(GLOB python_files *.py requirements.txt) +install(FILES ${python_files} DESTINATION python) + +# module +install(DIRECTORY caffe + DESTINATION python + FILES_MATCHING + PATTERN "*.py" + PATTERN "ilsvrc_2012_mean.npy" + PATTERN "test" EXCLUDE + ) + +# _caffe.so install(TARGETS pycaffe DESTINATION python/caffe) -install(DIRECTORY caffe/imagenet caffe/proto caffe/test DESTINATION python/caffe) - -