diff --git a/Makefile b/Makefile
index 5424c3a1858..568d9c2774d 100644
--- a/Makefile
+++ b/Makefile
@@ -272,7 +272,7 @@ endif
 ifeq ($(OSX), 1)
 	CXX := /usr/bin/clang++
 	ifneq ($(CPU_ONLY), 1)
-		CUDA_VERSION := $(shell $(CUDA_DIR)/bin/nvcc -V | grep -o 'release \d' | grep -o '\d')
+		CUDA_VERSION := $(shell $(CUDA_DIR)/bin/nvcc -V | grep -o 'release [0-9.]*' | grep -o '[0-9.]*')
 		ifeq ($(shell echo | awk '{exit $(CUDA_VERSION) < 7.0;}'), 1)
 			CXXFLAGS += -stdlib=libstdc++
 			LINKFLAGS += -stdlib=libstdc++
diff --git a/Makefile.config.example b/Makefile.config.example
index 8fd49c9c1a7..07bed63ae40 100644
--- a/Makefile.config.example
+++ b/Makefile.config.example
@@ -98,6 +98,7 @@ LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib
 # (Usually not necessary -- OpenCV libraries are normally installed in one of the above $LIBRARY_DIRS.)
 # USE_PKG_CONFIG := 1

+# N.B. both build and distribute dirs are cleared on `make clean`
 BUILD_DIR := build
 DISTRIBUTE_DIR := distribute

diff --git a/docker/Makefile b/docker/Makefile
index 725208c6b2b..0de887d0e19 100644
--- a/docker/Makefile
+++ b/docker/Makefile
@@ -22,7 +22,7 @@ docker_files: standalone_files

 standalone_files: standalone/cpu/Dockerfile standalone/gpu/Dockerfile

-FROM_GPU = "nvidia/cuda:cudnn"
+FROM_GPU = "nvidia/cuda:7.5-cudnn4-devel-ubuntu14.04"
 FROM_CPU = "ubuntu:14.04"
 GPU_CMAKE_ARGS = -DUSE_CUDNN=1
 CPU_CMAKE_ARGS = -DCPU_ONLY=1
diff --git a/docker/standalone/gpu/Dockerfile b/docker/standalone/gpu/Dockerfile
index 1ddc6560d16..371aad5b1e9 100644
--- a/docker/standalone/gpu/Dockerfile
+++ b/docker/standalone/gpu/Dockerfile
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:cudnn
+FROM nvidia/cuda:7.5-cudnn4-devel-ubuntu14.04
 MAINTAINER caffe-maint@googlegroups.com

 RUN apt-get update && apt-get install -y --no-install-recommends \
diff --git a/docs/installation.md b/docs/installation.md
index 893164584d9..1e29a49d82d 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -5,13 +5,23 @@ title: Installation
 # Installation

 Prior to installing, have a glance through this guide and take note of the details for your platform.
-We install and run Caffe on Ubuntu 14.04 and 12.04, OS X 10.10 / 10.9 / 10.8, and AWS.
-The official Makefile and `Makefile.config` build are complemented by an automatic CMake build from the community.
+We install and run Caffe on Ubuntu 16.04–12.04, OS X 10.11–10.8, and through Docker and AWS.
+The official Makefile and `Makefile.config` build are complemented by a [community CMake build](#cmake-build).
+
+**Step-by-step Instructions**:
+
+- [Docker setup](https://github.com/BVLC/caffe/tree/master/docker) *out-of-the-box brewing*
+- [Ubuntu installation](install_apt.html) *the standard platform*
+- [OS X installation](install_osx.html)
+- [RHEL / CentOS / Fedora installation](install_yum.html)
+- [Windows](https://github.com/BVLC/caffe/tree/windows) *see the Windows branch led by Microsoft*
+- [OpenCL](https://github.com/BVLC/caffe/tree/opencl) *see the OpenCL branch led by Fabian Tschopp*
+
+**Overview**:

 - [Prerequisites](#prerequisites)
 - [Compilation](#compilation)
 - [Hardware](#hardware)
-- Platforms: [Ubuntu guide](install_apt.html), [OS X guide](install_osx.html), and [RHEL / CentOS / Fedora guide](install_yum.html)

 When updating Caffe, it's best to `make clean` before re-compiling.

@@ -20,7 +30,7 @@ When updating Caffe, it's best to `make clean` before re-compiling.
 Caffe has several dependencies:

 * [CUDA](https://developer.nvidia.com/cuda-zone) is required for GPU mode.
-    * library version 7.0 and the latest driver version are recommended, but 6.* is fine too
+    * library version 7+ and the latest driver version are recommended, but 6.* is fine too
     * 5.5, and 5.0 are compatible but considered legacy
 * [BLAS](http://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms) via ATLAS, MKL, or OpenBLAS.
 * [Boost](http://www.boost.org/) >= 1.55
@@ -30,14 +40,14 @@ Optional dependencies:

 * [OpenCV](http://opencv.org/) >= 2.4 including 3.0
 * IO libraries: `lmdb`, `leveldb` (note: leveldb requires `snappy`)
-* cuDNN for GPU acceleration (v3)
+* cuDNN for GPU acceleration (v4)

 Pycaffe and Matcaffe interfaces have their own natural needs.

 * For Python Caffe: `Python 2.7` or `Python 3.3+`, `numpy (>= 1.7)`, boost-provided `boost.python`
 * For MATLAB Caffe: MATLAB with the `mex` compiler.

-**cuDNN Caffe**: for fastest operation Caffe is accelerated by drop-in integration of [NVIDIA cuDNN](https://developer.nvidia.com/cudnn). To speed up your Caffe models, install cuDNN then uncomment the `USE_CUDNN := 1` flag in `Makefile.config` when installing Caffe. Acceleration is automatic. The current version is cuDNN v3; older versions are supported in older Caffe.
+**cuDNN Caffe**: for fastest operation Caffe is accelerated by drop-in integration of [NVIDIA cuDNN](https://developer.nvidia.com/cudnn). To speed up your Caffe models, install cuDNN then uncomment the `USE_CUDNN := 1` flag in `Makefile.config` when installing Caffe. Acceleration is automatic. The current version is cuDNN v4; older versions are supported in older Caffe.

 **CPU-only Caffe**: for cold-brewed CPU-only Caffe uncomment the `CPU_ONLY := 1` flag in `Makefile.config` to configure and build Caffe without CUDA. This is helpful for cloud or cluster deployment.

@@ -82,10 +92,6 @@ Install MATLAB, and make sure that its `mex` is in your `$PATH`.

 *Caffe's MATLAB interface works with versions 2015a, 2014a/b, 2013a/b, and 2012b.*

-#### Windows
-
-There is an unofficial Windows port of Caffe at [niuzhiheng/caffe:windows](https://github.com/niuzhiheng/caffe). Thanks [@niuzhiheng](https://github.com/niuzhiheng)!
-
 ## Compilation

 Caffe can be compiled with either Make or CMake. Make is officially supported while CMake is supported by the community.

@@ -113,7 +119,7 @@ Be sure to set your MATLAB and Python paths in `Makefile.config` first!

 Now that you have installed Caffe, check out the [MNIST tutorial](gathered/examples/mnist.html) and the [reference ImageNet model tutorial](gathered/examples/imagenet.html).

-### Compilation with CMake
+### CMake Build

 In lieu of manually editing `Makefile.config` to configure the build, Caffe offers an unofficial CMake build thanks to @Nerei, @akosiorek, and other members of the community. It requires CMake version >= 2.8.7.
 The basic steps are as follows:

@@ -129,9 +135,9 @@ See [PR #1667](https://github.com/BVLC/caffe/pull/1667) for options and details.

 ## Hardware

-**Laboratory Tested Hardware**: Berkeley Vision runs Caffe with K40s, K20s, and Titans including models at ImageNet/ILSVRC scale. We also run on GTX series cards (980s and 770s) and GPU-equipped MacBook Pros. We have not encountered any trouble in-house with devices with CUDA capability >= 3.0. All reported hardware issues thus-far have been due to GPU configuration, overheating, and the like.
+**Laboratory Tested Hardware**: Berkeley Vision runs Caffe with Titan Xs, K80s, GTX 980s, K40s, K20s, Titans, and GTX 770s including models at ImageNet/ILSVRC scale.
We have not encountered any trouble in-house with devices with CUDA capability >= 3.0. All reported hardware issues thus far have been due to GPU configuration, overheating, and the like.

-**CUDA compute capability**: devices with compute capability <= 2.0 may have to reduce CUDA thread numbers and batch sizes due to hardware constraints. Your mileage may vary.
+**CUDA compute capability**: devices with compute capability <= 2.0 may have to reduce CUDA thread numbers and batch sizes due to hardware constraints. Brew with caution; we recommend compute capability >= 3.0.

 Once installed, check your times against our [reference performance numbers](performance_hardware.html) to make sure everything is configured properly.

diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp
index e1b89f42fb6..7385a74a679 100644
--- a/examples/cifar10/convert_cifar_data.cpp
+++ b/examples/cifar10/convert_cifar_data.cpp
@@ -91,6 +91,8 @@ void convert_dataset(const string& input_folder, const string& output_folder,
 }

 int main(int argc, char** argv) {
+  FLAGS_alsologtostderr = 1;
+
   if (argc != 4) {
     printf("This script converts the CIFAR dataset to the leveldb format used\n"
            "by caffe to perform classification.\n"
diff --git a/examples/cpp_classification/readme.md b/examples/cpp_classification/readme.md
index a086db1a035..0de2885b53c 100644
--- a/examples/cpp_classification/readme.md
+++ b/examples/cpp_classification/readme.md
@@ -42,7 +42,7 @@ script:
 The ImageNet labels file (also called the *synset file*) is also required in order to map a prediction to the name of the class:
 ```
-./data/ilsvrc12/get_ilsvrc_aux.sh.
+./data/ilsvrc12/get_ilsvrc_aux.sh
 ```
 Using the files that were downloaded, we can classify the provided cat image (`examples/images/cat.jpg`) using this command:
diff --git a/examples/finetune_flickr_style/readme.md b/examples/finetune_flickr_style/readme.md
index 9ba4c9217ff..188dedf1b9a 100644
--- a/examples/finetune_flickr_style/readme.md
+++ b/examples/finetune_flickr_style/readme.md
@@ -57,7 +57,11 @@ The prototxts in this example assume this, and also assume the presence of the I

 We'll also need the ImageNet-trained model, which you can obtain by running `./scripts/download_model_binary.py models/bvlc_reference_caffenet`.

-Now we can train! (You can fine-tune in CPU mode by leaving out the `-gpu` flag.)
+Now we can train! The key to fine-tuning is the `-weights` argument in the
+command below, which tells Caffe that we want to load weights from a pre-trained
+Caffe model.
+
+(You can fine-tune in CPU mode by leaving out the `-gpu` flag.)
     caffe % ./build/tools/caffe train -solver models/finetune_flickr_style/solver.prototxt -weights models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel -gpu 0

diff --git a/examples/mnist/convert_mnist_data.cpp b/examples/mnist/convert_mnist_data.cpp
index 5e8c1d6c85c..fc66a51c6f2 100644
--- a/examples/mnist/convert_mnist_data.cpp
+++ b/examples/mnist/convert_mnist_data.cpp
@@ -27,12 +27,15 @@
 #include <fstream>  // NOLINT(readability/streams)
 #include <string>

+#include "boost/scoped_ptr.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/db.hpp"
 #include "caffe/util/format.hpp"

 #if defined(USE_LEVELDB) && defined(USE_LMDB)

 using namespace caffe;  // NOLINT(build/namespaces)
+using boost::scoped_ptr;
 using std::string;

 DEFINE_string(backend, "lmdb", "The backend for storing the result");
@@ -72,43 +75,10 @@ void convert_dataset(const char* image_filename, const char* label_filename,
   image_file.read(reinterpret_cast<char*>(&cols), 4);
   cols = swap_endian(cols);

-  // lmdb
-  MDB_env *mdb_env;
-  MDB_dbi mdb_dbi;
-  MDB_val mdb_key, mdb_data;
-  MDB_txn *mdb_txn;
-  // leveldb
-  leveldb::DB* db;
-  leveldb::Options options;
-  options.error_if_exists = true;
-  options.create_if_missing = true;
-  options.write_buffer_size = 268435456;
-  leveldb::WriteBatch* batch = NULL;
-
-  // Open db
-  if (db_backend == "leveldb") {  // leveldb
-    LOG(INFO) << "Opening leveldb " << db_path;
-    leveldb::Status status = leveldb::DB::Open(
-        options, db_path, &db);
-    CHECK(status.ok()) << "Failed to open leveldb " << db_path
-        << ". Is it already existing?";
-    batch = new leveldb::WriteBatch();
-  } else if (db_backend == "lmdb") {  // lmdb
-    LOG(INFO) << "Opening lmdb " << db_path;
-    CHECK_EQ(mkdir(db_path, 0744), 0)
-        << "mkdir " << db_path << "failed";
-    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
-    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB
-        << "mdb_env_set_mapsize failed";
-    CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS)
-        << "mdb_env_open failed";
-    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)
-        << "mdb_txn_begin failed";
-    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
-        << "mdb_open failed. Does the lmdb already exist?
"; - } else { - LOG(FATAL) << "Unknown db backend " << db_backend; - } + + scoped_ptr db(db::GetDB(db_backend)); + db->Open(db_path, db::NEW); + scoped_ptr txn(db->NewTransaction()); // Storing to db char label; @@ -130,52 +100,19 @@ void convert_dataset(const char* image_filename, const char* label_filename, string key_str = caffe::format_int(item_id, 8); datum.SerializeToString(&value); - // Put in db - if (db_backend == "leveldb") { // leveldb - batch->Put(key_str, value); - } else if (db_backend == "lmdb") { // lmdb - mdb_data.mv_size = value.size(); - mdb_data.mv_data = reinterpret_cast(&value[0]); - mdb_key.mv_size = key_str.size(); - mdb_key.mv_data = reinterpret_cast(&key_str[0]); - CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS) - << "mdb_put failed"; - } else { - LOG(FATAL) << "Unknown db backend " << db_backend; - } + txn->Put(key_str, value); if (++count % 1000 == 0) { - // Commit txn - if (db_backend == "leveldb") { // leveldb - db->Write(leveldb::WriteOptions(), batch); - delete batch; - batch = new leveldb::WriteBatch(); - } else if (db_backend == "lmdb") { // lmdb - CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) - << "mdb_txn_commit failed"; - CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) - << "mdb_txn_begin failed"; - } else { - LOG(FATAL) << "Unknown db backend " << db_backend; - } + txn->Commit(); } } // write the last batch if (count % 1000 != 0) { - if (db_backend == "leveldb") { // leveldb - db->Write(leveldb::WriteOptions(), batch); - delete batch; - delete db; - } else if (db_backend == "lmdb") { // lmdb - CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed"; - mdb_close(mdb_env, mdb_dbi); - mdb_env_close(mdb_env); - } else { - LOG(FATAL) << "Unknown db backend " << db_backend; - } - LOG(ERROR) << "Processed " << count << " files."; + txn->Commit(); } + LOG(INFO) << "Processed " << count << " files."; delete[] pixels; + db->Close(); } int main(int argc, char** argv) { @@ -183,6 +120,8 @@ int main(int argc, char** argv) { namespace gflags = google; #endif + FLAGS_alsologtostderr = 1; + gflags::SetUsageMessage("This script converts the MNIST dataset to\n" "the lmdb/leveldb format used by Caffe to load data.\n" "Usage:\n" diff --git a/examples/mnist/readme.md b/examples/mnist/readme.md index b87a0f53c7a..35952155a30 100644 --- a/examples/mnist/readme.md +++ b/examples/mnist/readme.md @@ -248,7 +248,7 @@ These messages tell you the details about each layer, its connections and its ou I1203 solver.cpp:36] Solver scaffolding done. I1203 solver.cpp:44] Solving LeNet -Based on the solver setting, we will print the training loss function every 100 iterations, and test the network every 1000 iterations. You will see messages like this: +Based on the solver setting, we will print the training loss function every 100 iterations, and test the network every 500 iterations. You will see messages like this: I1203 solver.cpp:204] Iteration 100, lr = 0.00992565 I1203 solver.cpp:66] Iteration 100, loss = 0.26044 diff --git a/include/caffe/layers/crop_layer.hpp b/include/caffe/layers/crop_layer.hpp index 5c605b2ae9e..c4fda1220c3 100644 --- a/include/caffe/layers/crop_layer.hpp +++ b/include/caffe/layers/crop_layer.hpp @@ -44,6 +44,7 @@ class CropLayer : public Layer { vector offsets; private: + // Recursive copy function. 
   void crop_copy(const vector<Blob<Dtype>*>& bottom,
                const vector<Blob<Dtype>*>& top,
                const vector<int>& offsets,
@@ -53,6 +54,14 @@ class CropLayer : public Layer<Dtype> {
                Dtype* dest_data,
                bool is_forward);

+  // Recursive copy function: this is similar to crop_copy() but loops over all
+  // but the last two dimensions to allow for ND cropping while still relying on
+  // a CUDA kernel for the innermost two dimensions for performance reasons. An
+  // alternative implementation could rely on the kernel more by passing
+  // offsets, but this is problematic because of its variable length.
+  // Since in the standard (N,C,H,W) case N,C are usually not cropped a speedup
+  // could be achieved by not looping the application of the copy_kernel around
+  // these dimensions.
   void crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
                 const vector<Blob<Dtype>*>& top,
                 const vector<int>& offsets,
diff --git a/include/caffe/layers/parameter_layer.hpp b/include/caffe/layers/parameter_layer.hpp
new file mode 100644
index 00000000000..188b92acbe2
--- /dev/null
+++ b/include/caffe/layers/parameter_layer.hpp
@@ -0,0 +1,45 @@
+#ifndef CAFFE_PARAMETER_LAYER_HPP_
+#define CAFFE_PARAMETER_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class ParameterLayer : public Layer<Dtype> {
+ public:
+  explicit ParameterLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+    if (this->blobs_.size() > 0) {
+      LOG(INFO) << "Skipping parameter initialization";
+    } else {
+      this->blobs_.resize(1);
+      this->blobs_[0].reset(new Blob<Dtype>());
+      this->blobs_[0]->Reshape(this->layer_param_.parameter_param().shape());
+    }
+    top[0]->Reshape(this->layer_param_.parameter_param().shape());
+  }
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) { }
+  virtual inline const char* type() const { return "Parameter"; }
+  virtual inline int ExactNumBottomBlobs() const { return 0; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+    top[0]->ShareData(*(this->blobs_[0]));
+    top[0]->ShareDiff(*(this->blobs_[0]));
+  }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom)
+  { }
+};
+
+}  // namespace caffe
+
+#endif
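A minimal prototxt sketch (not part of the patch; layer and blob names hypothetical) of how the new layer is wired up: it takes no bottoms and exposes its single learnable blob, shaped by `parameter_param`, as a top:

```
layer {
  name: "weights"    # hypothetical name
  type: "Parameter"
  top: "weights"
  parameter_param { shape { dim: 1 dim: 10 } }
}
```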
diff --git a/include/caffe/layers/python_layer.hpp b/include/caffe/layers/python_layer.hpp
index b839d52684e..66dbbdf13b8 100644
--- a/include/caffe/layers/python_layer.hpp
+++ b/include/caffe/layers/python_layer.hpp
@@ -26,6 +26,7 @@ class PythonLayer : public Layer<Dtype> {
     }
     self_.attr("param_str") = bp::str(
         this->layer_param_.python_param().param_str());
+    self_.attr("phase") = static_cast<int>(this->phase_);
     self_.attr("setup")(bottom, top);
   }
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
diff --git a/include/caffe/util/db_lmdb.hpp b/include/caffe/util/db_lmdb.hpp
index 4e1568ace50..ee370322383 100644
--- a/include/caffe/util/db_lmdb.hpp
+++ b/include/caffe/util/db_lmdb.hpp
@@ -3,6 +3,7 @@
 #define CAFFE_UTIL_DB_LMDB_HPP

 #include <string>
+#include <vector>

 #include "lmdb.h"

@@ -54,14 +55,16 @@ class LMDBCursor : public Cursor {

 class LMDBTransaction : public Transaction {
  public:
-  explicit LMDBTransaction(MDB_dbi* mdb_dbi, MDB_txn* mdb_txn)
-    : mdb_dbi_(mdb_dbi), mdb_txn_(mdb_txn) { }
+  explicit LMDBTransaction(MDB_env* mdb_env)
+    : mdb_env_(mdb_env) { }
   virtual void Put(const string& key, const string& value);
-  virtual void Commit() { MDB_CHECK(mdb_txn_commit(mdb_txn_)); }
+  virtual void Commit();

  private:
-  MDB_dbi* mdb_dbi_;
-  MDB_txn* mdb_txn_;
+  MDB_env* mdb_env_;
+  vector<string> keys, values;
+
+  void DoubleMapSize();

   DISABLE_COPY_AND_ASSIGN(LMDBTransaction);
 };
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index a2c46a123aa..32b5d921094 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -26,6 +26,19 @@
 #define PyArray_SetBaseObject(arr, x) (PyArray_BASE(arr) = (x))
 #endif

+/* Fix to avoid registration warnings in pycaffe (#3960) */
+#define BP_REGISTER_SHARED_PTR_TO_PYTHON(PTR) do { \
+  const boost::python::type_info info = \
+    boost::python::type_id<shared_ptr<PTR > >(); \
+  const boost::python::converter::registration* reg = \
+    boost::python::converter::registry::query(info); \
+  if (reg == NULL) { \
+    bp::register_ptr_to_python<shared_ptr<PTR > >(); \
+  } else if ((*reg).m_to_python == NULL) { \
+    bp::register_ptr_to_python<shared_ptr<PTR > >(); \
+  } \
+} while (0)
+
 namespace bp = boost::python;

 namespace caffe {
@@ -255,7 +268,7 @@ BOOST_PYTHON_MODULE(_caffe) {
     .def("_set_input_arrays", &Net_SetInputArrays,
         bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
     .def("save", &Net_Save);
-  bp::register_ptr_to_python<shared_ptr<Net<Dtype> > >();
+  BP_REGISTER_SHARED_PTR_TO_PYTHON(Net<Dtype>);

   bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
     "Blob", bp::no_init)
@@ -275,7 +288,7 @@ BOOST_PYTHON_MODULE(_caffe) {
           NdarrayCallPolicies()))
     .add_property("diff", bp::make_function(&Blob<Dtype>::mutable_cpu_diff,
           NdarrayCallPolicies()));
-  bp::register_ptr_to_python<shared_ptr<Blob<Dtype> > >();
+  BP_REGISTER_SHARED_PTR_TO_PYTHON(Blob<Dtype>);

   bp::class_<Layer<Dtype>, shared_ptr<PythonLayer<Dtype> >,
     boost::noncopyable>("Layer", bp::init<const LayerParameter&>())
     .def("setup", &Layer<Dtype>::LayerSetUp)
     .def("reshape", &Layer<Dtype>::Reshape)
     .add_property("type", bp::make_function(&Layer<Dtype>::type));
-  bp::register_ptr_to_python<shared_ptr<Layer<Dtype> > >();
+  BP_REGISTER_SHARED_PTR_TO_PYTHON(Layer<Dtype>);

   bp::class_<LayerParameter>("LayerParameter", bp::no_init);

@@ -299,7 +312,7 @@ BOOST_PYTHON_MODULE(_caffe) {
     .def("step", &Solver<Dtype>::Step)
     .def("restore", &Solver<Dtype>::Restore)
     .def("snapshot", &Solver<Dtype>::Snapshot);
-  bp::register_ptr_to_python<shared_ptr<Solver<Dtype> > >();
+  BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);

   bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
     shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
diff --git a/python/caffe/classifier.py b/python/caffe/classifier.py
index 537193db8f8..ea29fed86f9 100644
--- a/python/caffe/classifier.py
+++ b/python/caffe/classifier.py
@@ -79,6 +79,7 @@ def predict(self, inputs, oversample=True):
                 -self.crop_dims / 2.0,
                 self.crop_dims / 2.0
             ])
+            crop = crop.astype(int)
             input_ = input_[:, crop[0]:crop[2], crop[1]:crop[3], :]

         # Classify
diff --git a/python/caffe/draw.py b/python/caffe/draw.py
index cfa3fc5b1fb..61205ca9f37 100644
--- a/python/caffe/draw.py
+++ b/python/caffe/draw.py
@@ -142,7 +142,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
     -------
     pydot graph object
     """
-    pydot_graph = pydot.Dot(caffe_net.name,
+    pydot_graph = pydot.Dot(caffe_net.name if caffe_net.name else 'Net',
                             graph_type='digraph',
                             rankdir=rankdir)
     pydot_nodes = {}
diff --git a/python/caffe/io.py b/python/caffe/io.py
index 75310589cec..cee5ace2e88 100644
--- a/python/caffe/io.py
+++ b/python/caffe/io.py
@@ -63,7 +63,7 @@ def blobprotovector_str_to_arraylist(str):
     return [blobproto_to_array(blob) for blob in vec.blobs]


-def array_to_datum(arr, label=0):
+def array_to_datum(arr, label=None):
     """Converts a 3-dimensional array to datum. If the array has dtype uint8,
     the output data will be encoded as a string. Otherwise, the output data
     will be stored in float format.
@@ -76,7 +76,8 @@ def array_to_datum(arr, label=0):
         datum.data = arr.tostring()
     else:
         datum.float_data.extend(arr.flat)
-    datum.label = label
+    if label is not None:
+        datum.label = label
     return datum
diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py
index 63de4cce4b2..5fb1f0b3fb1 100644
--- a/python/caffe/net_spec.py
+++ b/python/caffe/net_spec.py
@@ -32,7 +32,7 @@ def param_name_dict():
     # get all parameter names (typically underscore case) and corresponding
     # type names (typically camel case), which contain the layer names
     # (note that not all parameters correspond to layers, but we'll ignore that)
-    param_names = [s for s in dir(layer) if s.endswith('_param')]
+    param_names = [f.name for f in layer.DESCRIPTOR.fields if f.name.endswith('_param')]
     param_type_names = [type(getattr(layer, s)).__name__ for s in param_names]
     # strip the final '_param' or 'Parameter'
     param_names = [s[:-len('_param')] for s in param_names]
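For context on the `param_name_dict()` change above, a sketch (not part of the patch): `dir(layer)` returns methods and other attributes as well as message fields, while `DESCRIPTOR.fields` enumerates exactly the fields, so filtering it is the robust way to collect the `*_param` names. This assumes the generated `caffe_pb2` module is importable:

```python
# Sketch: enumerate LayerParameter's *_param fields via protobuf introspection.
from caffe.proto import caffe_pb2

layer = caffe_pb2.LayerParameter()
param_names = [f.name for f in layer.DESCRIPTOR.fields
               if f.name.endswith('_param')]
# e.g. ['transform_param', 'loss_param', 'convolution_param', ...]
print(param_names)
```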
diff --git a/python/caffe/test/test_io.py b/python/caffe/test/test_io.py
index 8c86ef75fb2..4a16b5b9128 100644
--- a/python/caffe/test/test_io.py
+++ b/python/caffe/test/test_io.py
@@ -39,3 +39,18 @@ def test_scalar(self):
         arr = caffe.io.blobproto_to_array(blob)

         self.assertEqual(arr, 123)
+
+
+class TestArrayToDatum(unittest.TestCase):
+
+    def test_label_none_size(self):
+        # Set label
+        d1 = caffe.io.array_to_datum(
+            np.ones((10,10,3)), label=1)
+        # Don't set label
+        d2 = caffe.io.array_to_datum(
+            np.ones((10,10,3)))
+        # Not setting the label should result in a smaller object
+        self.assertGreater(
+            len(d1.SerializeToString()),
+            len(d2.SerializeToString()))
diff --git a/python/caffe/test/test_python_layer.py b/python/caffe/test/test_python_layer.py
index 5669451fc68..9700edc521d 100644
--- a/python/caffe/test/test_python_layer.py
+++ b/python/caffe/test/test_python_layer.py
@@ -44,6 +44,18 @@ def forward(self, bottom, top):
     def backward(self, top, propagate_down, bottom):
         self.blobs[0].diff[0] = 1

+class PhaseLayer(caffe.Layer):
+    """A layer for checking attribute `phase`"""
+
+    def setup(self, bottom, top):
+        pass
+
+    def reshape(self, bottom, top):
+        top[0].reshape()
+
+    def forward(self, bottom, top):
+        top[0].data[()] = self.phase
+
 def python_net_file():
     with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
         f.write("""name: 'pythonnet' force_backward: true
@@ -76,6 +88,14 @@ def parameter_net_file():
         """)
     return f.name

+def phase_net_file():
+    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
+        f.write("""name: 'pythonnet' force_backward: true
+        layer { type: 'Python' name: 'layer' top: 'phase'
+                python_param { module: 'test_python_layer' layer: 'PhaseLayer' } }
+        """)
+    return f.name
+

 @unittest.skipIf('Python' not in caffe.layer_type_list(),
     'Caffe built without Python layer support')
@@ -144,3 +164,9 @@ def test_parameter(self):
         self.assertEqual(layer.blobs[0].data[0], 1)

         os.remove(net_file)
+
+    def test_phase(self):
+        net_file = phase_net_file()
+        for phase in caffe.TRAIN, caffe.TEST:
+            net = caffe.Net(net_file, phase)
+            self.assertEqual(net.forward()['phase'], phase)
diff --git a/scripts/download_model_binary.py b/scripts/download_model_binary.py
index 66f72f2477e..fcdbb5a91a2 100755
--- a/scripts/download_model_binary.py
+++ b/scripts/download_model_binary.py
@@ -60,7 +60,7 @@ def valid_dirname(dirname):

     # Closure-d function for checking SHA1.
     def model_checks_out(filename=model_filename, sha1=frontmatter['sha1']):
-        with open(filename, 'r') as f:
+        with open(filename, 'rb') as f:
             return hashlib.sha1(f.read()).hexdigest() == sha1

     # Check if model exists.
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index c86fd5d1d94..4a34e4c5856 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -30,7 +30,9 @@ void Blob<Dtype>::Reshape(const vector<int>& shape) {
   int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
   for (int i = 0; i < shape.size(); ++i) {
     CHECK_GE(shape[i], 0);
-    CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
+    if (count_ != 0) {
+      CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
+    }
     count_ *= shape[i];
     shape_[i] = shape[i];
     shape_data[i] = shape[i];
diff --git a/src/caffe/layers/crop_layer.cpp b/src/caffe/layers/crop_layer.cpp
index e81bdd732f3..aecdcd63194 100644
--- a/src/caffe/layers/crop_layer.cpp
+++ b/src/caffe/layers/crop_layer.cpp
@@ -15,8 +15,7 @@ namespace caffe {
 template <typename Dtype>
 void CropLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  // All logic that depends only on the number of dimensions is here,
-  // the rest is in Reshape because it depends on Blob size.
+  // LayerSetUp() handles the number of dimensions; Reshape() handles the sizes.
   // bottom[0] supplies the data
   // bottom[1] supplies the size
   const CropParameter& param = this->layer_param_.crop_param();
@@ -40,41 +39,35 @@ void CropLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
   int input_dim = bottom[0]->num_axes();
   const int start_axis = bottom[0]->CanonicalAxisIndex(param.axis());

-  // initialize all offsets to 0
+  // Initialize offsets to 0 and the new shape to the current shape of the data.
   offsets = vector<int>(input_dim, 0);
-  // initialize new shape to bottom[0]
   vector<int> new_shape(bottom[0]->shape());

-  // apply crops
+  // Determine crop offsets and the new shape post-crop.
   for (int i = 0; i < input_dim; ++i) {
     int crop_offset = 0;
-    int new_size    = bottom[0]->shape(i);
+    int new_size = bottom[0]->shape(i);
     if (i >= start_axis) {
       new_size = bottom[1]->shape(i);
       if (param.offset_size() == 1) {
-        // if only one crop value is supplied, crop all dimensions after axis
-        // by this crop value
+        // If only one offset is given, all crops have the same offset.
         crop_offset = param.offset(0);
       } else if (param.offset_size() > 1) {
-        // crop values specified must be equal to the number of dimensions
-        // following axis
+        // For several offsets, the number of offsets must be equal to the
+        // number of dimensions to crop, that is dimensions after the axis.
         crop_offset = param.offset(i - start_axis);
       }
+      // Check that the crop and offset are within the dimension's bounds.
+      CHECK_GE(bottom[0]->shape(i) - crop_offset, bottom[1]->shape(i))
+          << "the crop for dimension " << i << " is out-of-bounds with "
+          << "size " << bottom[1]->shape(i) << " and offset " << crop_offset;
     }
-    // Check that the image we are cropping minus the margin is bigger
-    // than the destination image.
-    CHECK_GE(bottom[0]->shape(i) - crop_offset,
-             bottom[1]->shape(i))
-        << "invalid crop parameters in dimension: " << i;
-    // Now set new size and offsets
     new_shape[i] = new_size;
     offsets[i] = crop_offset;
   }
   top[0]->Reshape(new_shape);
 }

-// recursive copy function
 template <typename Dtype>
 void CropLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom,
              const vector<Blob<Dtype>*>& top,
diff --git a/src/caffe/layers/crop_layer.cu b/src/caffe/layers/crop_layer.cu
index 9ed8f7cce57..f78cecbbeee 100644
--- a/src/caffe/layers/crop_layer.cu
+++ b/src/caffe/layers/crop_layer.cu
@@ -22,15 +22,6 @@ __global__ void copy_kernel(const int n, const int height, const int width,
   }
 }

-// recursive copy function, this function is similar to crop_copy but loops
-// over all but the last two dimensions. It is implemented this way to allow
-// for ND cropping while still relying on a CUDA kernel for the innermost
-// two dimensions for performance reasons.
-// An alternative way to implement ND cropping relying more on the kernel
-// would require passing offsets to the kernel, which is a bit problematic
-// because it is of variable length. Since in the standard (N,C,W,H) case
-// N,C are usually not cropped a speedup could be achieved by not looping
-// the application of the copy_kernel around these dimensions.
 template <typename Dtype>
 void CropLayer<Dtype>::crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
              const vector<Blob<Dtype>*>& top,
             const vector<int>& offsets,
diff --git a/src/caffe/layers/exp_layer.cpp b/src/caffe/layers/exp_layer.cpp
index 1f4a309fe25..0c1b463ae12 100644
--- a/src/caffe/layers/exp_layer.cpp
+++ b/src/caffe/layers/exp_layer.cpp
@@ -23,7 +23,8 @@ void ExpLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   const Dtype input_scale = this->layer_param_.exp_param().scale();
   const Dtype input_shift = this->layer_param_.exp_param().shift();
   inner_scale_ = log_base * input_scale;
-  outer_scale_ = (input_shift == Dtype(0)) ? Dtype(1) : pow(base, input_shift);
+  outer_scale_ = (input_shift == Dtype(0)) ? Dtype(1) :
+     ( (base != Dtype(-1)) ? pow(base, input_shift) : exp(input_shift) );
 }

 template <typename Dtype>
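A standalone sketch (not part of the patch) of what the `ExpLayer` setup now computes: y = base^(scale * x + shift) = base^shift * base^(scale * x), where base = -1 is Caffe's sentinel for the natural base e, so base^shift must be computed as exp(shift) rather than pow(-1, shift):

```cpp
#include <cmath>

// y = base^(scale * x + shift); base == -1 means "use e".
double exp_layer_forward(double x, double base, double scale, double shift) {
  const double log_base = (base == -1.0) ? 1.0 : std::log(base);
  const double inner_scale = log_base * scale;
  const double outer_scale = (shift == 0.0) ? 1.0
      : ((base != -1.0) ? std::pow(base, shift) : std::exp(shift));
  return outer_scale * std::exp(inner_scale * x);  // == base^(scale*x + shift)
}
```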
diff --git a/src/caffe/layers/parameter_layer.cpp b/src/caffe/layers/parameter_layer.cpp
new file mode 100644
index 00000000000..fbd326f8469
--- /dev/null
+++ b/src/caffe/layers/parameter_layer.cpp
@@ -0,0 +1,8 @@
+#include "caffe/layers/parameter_layer.hpp"
+
+namespace caffe {
+
+INSTANTIATE_CLASS(ParameterLayer);
+REGISTER_LAYER_CLASS(Parameter);
+
+}  // namespace caffe
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 23d94c97c07..f0bf594936c 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -427,12 +427,11 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
-  bool propagate_down = true;
+  bool need_backward = blob_need_backward_[blob_id];
   // Check if the backpropagation on bottom_id should be skipped
-  if (layer_param.propagate_down_size() > 0)
-    propagate_down = layer_param.propagate_down(bottom_id);
-  const bool need_backward = blob_need_backward_[blob_id] &&
-                          propagate_down;
+  if (layer_param.propagate_down_size() > 0) {
+    need_backward = layer_param.propagate_down(bottom_id);
+  }
   bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
 }
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 148e0bcb4a9..77f5f915965 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -306,7 +306,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 145 (last added: crop_param)
+// LayerParameter next available layer-specific ID: 146 (last added: parameter_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -328,7 +328,12 @@ message LayerParameter {
   // The blobs containing the numeric parameters of the layer.
   repeated BlobProto blobs = 7;

-  // Specifies on which bottoms the backpropagation should be skipped.
+  // Specifies whether to backpropagate to each bottom. If unspecified,
+  // Caffe will automatically infer whether each input needs backpropagation
+  // to compute parameter gradients. If set to true for some inputs,
+  // backpropagation to those inputs is forced; if set false for some inputs,
+  // backpropagation to those inputs is skipped.
+  //
   // The size must be either 0 or equal to the number of bottoms.
   repeated bool propagate_down = 11;

@@ -380,6 +385,7 @@ message LayerParameter {
   optional LRNParameter lrn_param = 118;
   optional MemoryDataParameter memory_data_param = 119;
   optional MVNParameter mvn_param = 120;
+  optional ParameterParameter parameter_param = 145;
   optional PoolingParameter pooling_param = 121;
   optional PowerParameter power_param = 122;
   optional PReLUParameter prelu_param = 131;
@@ -870,6 +876,10 @@ message MVNParameter {
   optional float eps = 3 [default = 1e-9];
 }

+message ParameterParameter {
+  optional BlobShape shape = 1;
+}
+
 message PoolingParameter {
   enum PoolMethod {
     MAX = 0;
@@ -984,7 +994,7 @@ message ReshapeParameter {
   //   reshape_param { shape { dim:  2  dim: 2  dim:  4 } }
   //   reshape_param { shape { dim:  0  dim: 2  dim:  4 } }
   //   reshape_param { shape { dim:  0  dim: 2  dim: -1 } }
-  //   reshape_param { shape { dim: -1  dim: 0  dim:  2 } }
+  //   reshape_param { shape { dim:  0  dim:-1  dim:  4 } }
   //
   optional BlobShape shape = 1;
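To make the new `propagate_down` semantics concrete, a minimal prototxt sketch (not part of the patch; layer and blob names hypothetical): with the flag unset, Caffe infers backpropagation needs; setting it forces or suppresses the gradient for that bottom, as the test below exercises:

```
layer {
  name: "innerproduct"    # hypothetical
  type: "InnerProduct"
  bottom: "data"
  top: "innerproduct"
  propagate_down: true    # force a gradient w.r.t. 'data'
  inner_product_param { num_output: 1 }
}
```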
diff --git a/src/caffe/test/test_blob.cpp b/src/caffe/test/test_blob.cpp
index 4e231cdee8b..4990475bb2a 100644
--- a/src/caffe/test/test_blob.cpp
+++ b/src/caffe/test/test_blob.cpp
@@ -53,6 +53,14 @@ TYPED_TEST(BlobSimpleTest, TestReshape) {
   EXPECT_EQ(this->blob_->count(), 120);
 }

+TYPED_TEST(BlobSimpleTest, TestReshapeZero) {
+  vector<int> shape(2);
+  shape[0] = 0;
+  shape[1] = 5;
+  this->blob_->Reshape(shape);
+  EXPECT_EQ(this->blob_->count(), 0);
+}
+
 TYPED_TEST(BlobSimpleTest, TestLegacyBlobProtoShapeEquals) {
   BlobProto blob_proto;
diff --git a/src/caffe/test/test_crop_layer.cpp b/src/caffe/test/test_crop_layer.cpp
index 45f24e2ee8d..ce2c736f644 100644
--- a/src/caffe/test/test_crop_layer.cpp
+++ b/src/caffe/test/test_crop_layer.cpp
@@ -91,6 +91,24 @@ TYPED_TEST(CropLayerTest, TestSetupShapeNegativeIndexing) {
   }
 }

+TYPED_TEST(CropLayerTest, TestDimensionsCheck) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  // Reshape size blob to have incompatible sizes for uncropped dimensions:
+  // the size blob has more channels than the data blob, but this is fine
+  // since the channels dimension is not cropped in this configuration.
+  this->blob_bottom_1_->Reshape(2, 5, 4, 2);
+  CropLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int i = 0; i < this->blob_top_->num_axes(); ++i) {
+    if (i < 2) {
+      EXPECT_EQ(this->blob_bottom_0_->shape(i), this->blob_top_->shape(i));
+    } else {
+      EXPECT_EQ(this->blob_bottom_1_->shape(i), this->blob_top_->shape(i));
+    }
+  }
+}
+
 TYPED_TEST(CropLayerTest, TestCropAll) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
index 1e0788ec127..92fd317fee8 100644
--- a/src/caffe/test/test_net.cpp
+++ b/src/caffe/test/test_net.cpp
@@ -716,6 +716,61 @@ class NetTest : public MultiDeviceTest<TypeParam> {
     InitNetFromProtoString(proto);
   }

+  virtual void InitForcePropNet(bool test_force_true) {
+    string proto =
+        "name: 'ForcePropTestNetwork' "
+        "layer { "
+        "  name: 'data' "
+        "  type: 'DummyData' "
+        "  dummy_data_param { "
+        "    shape { "
+        "      dim: 5 "
+        "      dim: 2 "
+        "      dim: 3 "
+        "      dim: 4 "
+        "    } "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    shape { "
+        "      dim: 5 "
+        "    } "
+        "    data_filler { "
+        "      type: 'constant' "
+        "      value: 0 "
+        "    } "
+        "  } "
+        "  top: 'data' "
+        "  top: 'label' "
+        "} "
+        "layer { "
+        "  name: 'innerproduct' "
+        "  type: 'InnerProduct' "
+        "  inner_product_param { "
+        "    num_output: 1 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "  } "
+        "  bottom: 'data' "
+        "  top: 'innerproduct' ";
+    if (test_force_true) {
+      proto += "  propagate_down: true ";
+    }
+    proto +=
+        "} "
+        "layer { "
+        "  name: 'loss' "
+        "  bottom: 'innerproduct' "
+        "  bottom: 'label' "
+        "  top: 'cross_entropy_loss' "
+        "  type: 'SigmoidCrossEntropyLoss' "
+        "} ";
+    InitNetFromProtoString(proto);
+  }
+
   int seed_;
   shared_ptr<Net<Dtype> > net_;
 };
@@ -2371,4 +2426,51 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
   }
 }

+TYPED_TEST(NetTest, TestForcePropagateDown) {
+  this->InitForcePropNet(false);
+  vector<bool> layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    const string& layer_name = this->net_->layer_names()[layer_id];
+    const vector<bool> need_backward =
+        this->net_->bottom_need_backward()[layer_id];
+    if (layer_name == "data") {
+      ASSERT_EQ(need_backward.size(), 0);
+      EXPECT_FALSE(layer_need_backward[layer_id]);
+    } else if (layer_name == "innerproduct") {
+      ASSERT_EQ(need_backward.size(), 1);
+      EXPECT_FALSE(need_backward[0]);  // data
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else if (layer_name == "loss") {
+      ASSERT_EQ(need_backward.size(), 2);
+      EXPECT_TRUE(need_backward[0]);  // innerproduct
+      EXPECT_FALSE(need_backward[1]);  // label
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else {
+      LOG(FATAL) << "Unknown layer: " << layer_name;
+    }
+  }
+  this->InitForcePropNet(true);
+  layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    const string& layer_name = this->net_->layer_names()[layer_id];
+    const vector<bool> need_backward =
+        this->net_->bottom_need_backward()[layer_id];
+    if (layer_name == "data") {
+      ASSERT_EQ(need_backward.size(), 0);
+      EXPECT_FALSE(layer_need_backward[layer_id]);
+    } else if (layer_name == "innerproduct") {
+      ASSERT_EQ(need_backward.size(), 1);
+      EXPECT_TRUE(need_backward[0]);  // data
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else if (layer_name == "loss") {
+      ASSERT_EQ(need_backward.size(), 2);
+      EXPECT_TRUE(need_backward[0]);  // innerproduct
+      EXPECT_FALSE(need_backward[1]);  // label
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else {
+      LOG(FATAL) << "Unknown layer: " << layer_name;
+    }
+  }
+}
+
 }  // namespace caffe
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index dd591f7d204..342f825cec3 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -394,6 +394,26 @@ TYPED_TEST(NeuronLayerTest, TestExpGradient) {
   this->TestExpGradient(kBase, kScale, kShift);
 }

+TYPED_TEST(NeuronLayerTest, TestExpLayerWithShift) {
+  typedef typename TypeParam::Dtype Dtype;
+  // Test default base of "-1" -- should actually set base := e,
+  // with a non-zero shift
+  const Dtype kBase = -1;
+  const Dtype kScale = 1;
+  const Dtype kShift = 1;
+  this->TestExpForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestExpGradientWithShift) {
+  typedef typename TypeParam::Dtype Dtype;
+  // Test default base of "-1" -- should actually set base := e,
+  // with a non-zero shift
+  const Dtype kBase = -1;
+  const Dtype kScale = 1;
+  const Dtype kShift = 1;
+  this->TestExpGradient(kBase, kScale, kShift);
+}
+
 TYPED_TEST(NeuronLayerTest, TestExpLayerBase2) {
   typedef typename TypeParam::Dtype Dtype;
   const Dtype kBase = 2;
diff --git a/src/caffe/util/db_lmdb.cpp b/src/caffe/util/db_lmdb.cpp
index cd17447989d..55e14d21275 100644
--- a/src/caffe/util/db_lmdb.cpp
+++ b/src/caffe/util/db_lmdb.cpp
@@ -12,20 +12,8 @@
 namespace caffe {
 namespace db {

-#ifdef _MSC_VER
-// On Windows lmdb creates file with the full size causing test failures due
-// to insufficient disk space. We will reduce lmdb size to make tests pass.
-const size_t LMDB_MAP_SIZE = 104857600;  // 100 MB
-// Constant will overflow on 32-bit build, assert that we are using correct
-// build.
-static_assert(sizeof(size_t) >= 8, "LMDB size overflow.");
-#else
-const size_t LMDB_MAP_SIZE = 1099511627776;  // 1 TB
-#endif
-
 void LMDB::Open(const string& source, Mode mode) {
   MDB_CHECK(mdb_env_create(&mdb_env_));
-  MDB_CHECK(mdb_env_set_mapsize(mdb_env_, LMDB_MAP_SIZE));
   if (mode == NEW) {
     CHECK_EQ(mkdir(source.c_str(), 0744), 0) << "mkdir " << source << "failed";
   }
@@ -62,19 +50,61 @@ LMDBCursor* LMDB::NewCursor() {
 }

 LMDBTransaction* LMDB::NewTransaction() {
-  MDB_txn* mdb_txn;
-  MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, 0, &mdb_txn));
-  MDB_CHECK(mdb_dbi_open(mdb_txn, NULL, 0, &mdb_dbi_));
-  return new LMDBTransaction(&mdb_dbi_, mdb_txn);
+  return new LMDBTransaction(mdb_env_);
 }

 void LMDBTransaction::Put(const string& key, const string& value) {
-  MDB_val mdb_key, mdb_value;
-  mdb_key.mv_data = const_cast<char*>(key.data());
-  mdb_key.mv_size = key.size();
-  mdb_value.mv_data = const_cast<char*>(value.data());
-  mdb_value.mv_size = value.size();
-  MDB_CHECK(mdb_put(mdb_txn_, *mdb_dbi_, &mdb_key, &mdb_value, 0));
+  keys.push_back(key);
+  values.push_back(value);
+}
+
+void LMDBTransaction::Commit() {
+  MDB_dbi mdb_dbi;
+  MDB_val mdb_key, mdb_data;
+  MDB_txn *mdb_txn;
+
+  // Initialize MDB variables
+  MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, 0, &mdb_txn));
+  MDB_CHECK(mdb_dbi_open(mdb_txn, NULL, 0, &mdb_dbi));
+
+  bool out_of_memory = false;
+  for (int i = 0; i < keys.size(); i++) {
+    mdb_key.mv_size = keys[i].size();
+    mdb_key.mv_data = const_cast<char*>(keys[i].data());
+    mdb_data.mv_size = values[i].size();
+    mdb_data.mv_data = const_cast<char*>(values[i].data());
+
+    int put_rc = mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0);
+    if (put_rc == MDB_MAP_FULL) {
+      out_of_memory = true;
+      break;
+    } else {
+      // Failed for some other reason
+      MDB_CHECK(put_rc);
+    }
+  }
+
+  if (!out_of_memory) {
+    // Commit the transaction
+    MDB_CHECK(mdb_txn_commit(mdb_txn));
+    mdb_dbi_close(mdb_env_, mdb_dbi);
+    keys.clear();
+    values.clear();
+  } else {
+    // Double the map size and retry
+    mdb_txn_abort(mdb_txn);
+    mdb_dbi_close(mdb_env_, mdb_dbi);
+    DoubleMapSize();
+    Commit();
+  }
+}
+
+void LMDBTransaction::DoubleMapSize() {
+  struct MDB_envinfo current_info;
+  MDB_CHECK(mdb_env_info(mdb_env_, &current_info));
+  size_t new_size = current_info.me_mapsize * 2;
+  DLOG(INFO) << "Doubling LMDB map size to " << (new_size >> 20) << "MB ...";
+  MDB_CHECK(mdb_env_set_mapsize(mdb_env_, new_size));
 }

 }  // namespace db
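A usage sketch of the transaction semantics introduced above (not part of the patch; path and key hypothetical): `Put()` now only buffers pairs in memory, and `Commit()` performs the actual `mdb_put` loop, doubling the map size and retrying whenever LMDB reports `MDB_MAP_FULL`, so no fixed map size needs to be chosen up front:

```cpp
#include "boost/scoped_ptr.hpp"
#include "caffe/util/db.hpp"

using boost::scoped_ptr;
using namespace caffe;  // NOLINT(build/namespaces)

void write_example() {
  scoped_ptr<db::DB> db(db::GetDB("lmdb"));
  db->Open("/tmp/example_lmdb", db::NEW);  // path hypothetical
  scoped_ptr<db::Transaction> txn(db->NewTransaction());
  txn->Put("00000000", "value");  // buffered only; nothing on disk yet
  txn->Commit();                  // single write pass; grows map as needed
  db->Close();
}
```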
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 80ad762fe51..3af434b4053 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -33,7 +33,7 @@ DEFINE_string(gpu, "",
 DEFINE_string(solver, "",
     "The solver definition protocol buffer text file.");
 DEFINE_string(model, "",
-    "The model definition protocol buffer text file..");
+    "The model definition protocol buffer text file.");
 DEFINE_string(snapshot, "",
     "Optional; the snapshot solver state to resume training.");
 DEFINE_string(weights, "",
diff --git a/tools/extra/plot_training_log.py.example b/tools/extra/plot_training_log.py.example
index 4d3ed0d15a9..79924ae5a5a 100755
--- a/tools/extra/plot_training_log.py.example
+++ b/tools/extra/plot_training_log.py.example
@@ -10,7 +10,8 @@ import matplotlib.legend as lgd
 import matplotlib.markers as mks

 def get_log_parsing_script():
-    dirname = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
+    dirname = os.path.dirname(os.path.abspath(inspect.getfile(
+        inspect.currentframe())))
     return dirname + '/parse_log.sh'

 def get_log_file_suffix():
@@ -61,16 +62,17 @@ def get_data_file_type(chart_type):
     return data_file_type

 def get_data_file(chart_type, path_to_log):
-    return os.path.basename(path_to_log) + '.' + get_data_file_type(chart_type).lower()
+    return (os.path.basename(path_to_log) + '.' +
+            get_data_file_type(chart_type).lower())

 def get_field_descriptions(chart_type):
     description = get_chart_type_description(chart_type).split(
         get_chart_type_description_separator())
     y_axis_field = description[0]
     x_axis_field = description[1]
-    return x_axis_field, y_axis_field    
+    return x_axis_field, y_axis_field

-def get_field_indecies(x_axis_field, y_axis_field):
+def get_field_indices(x_axis_field, y_axis_field):
     data_file_type = get_data_file_type(chart_type)
     fields = create_field_index()[0][data_file_type]
     return fields[x_axis_field], fields[y_axis_field]
@@ -111,7 +113,7 @@ def plot_chart(chart_type, path_to_png, path_to_log_list):
         os.system('%s %s' % (get_log_parsing_script(), path_to_log))
         data_file = get_data_file(chart_type, path_to_log)
         x_axis_field, y_axis_field = get_field_descriptions(chart_type)
-        x, y = get_field_indecies(x_axis_field, y_axis_field)
+        x, y = get_field_indices(x_axis_field, y_axis_field)
         data = load_data(data_file, x, y)
         ## TODO: more systematic color cycle for lines
         color = [random.random(), random.random(), random.random()]
@@ -138,8 +140,8 @@ def plot_chart(chart_type, path_to_png, path_to_log_list):
     plt.legend(loc = legend_loc, ncol = 1)  # ajust ncol to fit the space
     plt.title(get_chart_type_description(chart_type))
     plt.xlabel(x_axis_field)
-    plt.ylabel(y_axis_field)  
-    plt.savefig(path_to_png)  
+    plt.ylabel(y_axis_field)
+    plt.savefig(path_to_png)
     plt.show()

 def print_help():
@@ -160,28 +162,30 @@ Supported chart types:""" % (len(get_supported_chart_types()) - 1,
     num = len(supported_chart_types)
     for i in xrange(num):
         print '    %d: %s' % (i, supported_chart_types[i])
-    exit
+    sys.exit()

 def is_valid_chart_type(chart_type):
     return chart_type >= 0 and chart_type < len(get_supported_chart_types())
-    
+
 if __name__ == '__main__':
     if len(sys.argv) < 4:
         print_help()
     else:
         chart_type = int(sys.argv[1])
         if not is_valid_chart_type(chart_type):
+            print '%s is not a valid chart type.' % chart_type
             print_help()
         path_to_png = sys.argv[2]
         if not path_to_png.endswith('.png'):
             print 'Path must ends with png' % path_to_png
-            exit
+            sys.exit()
         path_to_logs = sys.argv[3:]
         for path_to_log in path_to_logs:
            if not os.path.exists(path_to_log):
                 print 'Path does not exist: %s' % path_to_log
-                exit
+                sys.exit()
             if not path_to_log.endswith(get_log_file_suffix()):
+                print 'Log file must end in %s.' % get_log_file_suffix()
                 print_help()
         ## plot_chart accpets multiple path_to_logs
         plot_chart(chart_type, path_to_png, path_to_logs)
diff --git a/windows/libcaffe/libcaffe.vcxproj b/windows/libcaffe/libcaffe.vcxproj
index fce0a30ed43..292a844db8a 100644
--- a/windows/libcaffe/libcaffe.vcxproj
+++ b/windows/libcaffe/libcaffe.vcxproj
@@ -144,6 +144,7 @@
+    <ClCompile Include="..\..\src\caffe\layers\parameter_layer.cpp" />
@@ -245,6 +246,7 @@
+    <ClInclude Include="..\..\include\caffe\layers\parameter_layer.hpp" />
diff --git a/windows/libcaffe/libcaffe.vcxproj.filters b/windows/libcaffe/libcaffe.vcxproj.filters
index f781b823f6b..cbe4c60c944 100644
--- a/windows/libcaffe/libcaffe.vcxproj.filters
+++ b/windows/libcaffe/libcaffe.vcxproj.filters
@@ -267,6 +267,9 @@
       <Filter>src\layers</Filter>
     </ClCompile>
+    <ClCompile Include="..\..\src\caffe\layers\parameter_layer.cpp">
+      <Filter>src\layers</Filter>
+    </ClCompile>
       <Filter>src\layers</Filter>
@@ -521,6 +524,9 @@
       <Filter>include\layers</Filter>
     </ClInclude>
+    <ClInclude Include="..\..\include\caffe\layers\parameter_layer.hpp">
+      <Filter>include\layers</Filter>
+    </ClInclude>
       <Filter>include\layers</Filter>
diff --git a/windows/nuget.config b/windows/nuget.config
index fc77aae0d3f..ea7ca993c5a 100644
--- a/windows/nuget.config
+++ b/windows/nuget.config
@@ -1,4 +1,7 @@
 <?xml version="1.0" encoding="utf-8"?>
 <configuration>
+  <config>
+    <add key="repositoryPath" value="..\..\NugetPackages" />
+  </config>
\ No newline at end of file