Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose SGDSolver to pycaffe #286

Merged
merged 7 commits into from
Apr 9, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ template <typename Dtype>
class Solver {
public:
explicit Solver(const SolverParameter& param);
explicit Solver(const string& param_file);
void Init(const SolverParameter& param);
// The main entry of the solver function. By default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
virtual ~Solver() {}
inline Net<Dtype>* net() { return net_.get(); }
inline shared_ptr<Net<Dtype> > net() { return net_; }

protected:
// PreSolve is run before any solving iteration starts, allowing one to
Expand Down Expand Up @@ -53,6 +55,8 @@ class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) {}
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) {}

protected:
virtual void PreSolve();
Expand Down
2 changes: 1 addition & 1 deletion python/caffe/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .pycaffe import Net
from .pycaffe import Net, SGDSolver
67 changes: 47 additions & 20 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ using boost::python::object;
using boost::python::handle;
using boost::python::vector_indexing_suite;

// For convenience, check that an input file can be opened, and raise an
// exception that boost will translate into a Python exception if not.
// (Caffe could still crash later if the file is disturbed between this
// check and its actual use, but this saves frustration in most cases.)
static void CheckFile(const string& filename) {
  std::ifstream f(filename.c_str());
  if (!f.good()) {
    throw std::runtime_error("Could not open file " + filename);
  }
  // No explicit close() needed: the ifstream destructor closes the file
  // (RAII), including on the throwing path above.
}

// wrap shared_ptr<Blob<float> > in a class that we construct in C++ and pass
// to Python
Expand Down Expand Up @@ -123,27 +135,16 @@ class CaffeLayer {
// A simple wrapper over CaffeNet that runs the forward process.
struct CaffeNet {
CaffeNet(string param_file, string pretrained_param_file) {
// for convenience, check that the input files can be opened, and raise
// an exception that boost will send to Python if not
// (this function could still crash if the input files are disturbed
// before Net construction)
std::ifstream f(param_file.c_str());
if (!f.good()) {
f.close();
throw std::runtime_error("Could not open file " + param_file);
}
f.close();
f.open(pretrained_param_file.c_str());
if (!f.good()) {
f.close();
throw std::runtime_error("Could not open file " + pretrained_param_file);
}
f.close();
CheckFile(param_file);
CheckFile(pretrained_param_file);

net_.reset(new Net<float>(param_file));
net_->CopyTrainedLayersFrom(pretrained_param_file);
}

CaffeNet(shared_ptr<Net<float> > net)
: net_(net) {}

virtual ~CaffeNet() {}

inline void check_array_against_blob(
Expand Down Expand Up @@ -282,12 +283,31 @@ struct CaffeNet {
shared_ptr<Net<float> > net_;
};

// Thin wrapper exposing SGDSolver<float> to Python. Holds the solver in a
// shared_ptr; net() hands the solver's (shared) net back to Python wrapped
// in a CaffeNet.
class CaffeSGDSolver {
 public:
  // As in CaffeNet (as a convenience, not a guarantee), raise a Python
  // exception if param_file can't be opened.
  // explicit: a lone string should not silently convert into a solver.
  explicit CaffeSGDSolver(const string& param_file) {
    CheckFile(param_file);
    solver_.reset(new SGDSolver<float>(param_file));
  }

  // Expose the solver's training net to Python.
  CaffeNet net() { return CaffeNet(solver_->net()); }
  // Train from scratch.
  void Solve() { solver_->Solve(); }
  // Resume training from a snapshot file; check the file up front so a bad
  // path surfaces as a Python exception instead of a crash later.
  void SolveResume(const string& resume_file) {
    CheckFile(resume_file);
    solver_->Solve(resume_file);
  }

 protected:
  shared_ptr<SGDSolver<float> > solver_;
};


// The boost python module definition.
BOOST_PYTHON_MODULE(_caffe) {
boost::python::class_<CaffeNet>(
"CaffeNet", boost::python::init<string, string>())
"Net", boost::python::init<string, string>())
.def("Forward", &CaffeNet::Forward)
.def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
.def("Backward", &CaffeNet::Backward)
Expand All @@ -296,11 +316,12 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("set_phase_train", &CaffeNet::set_phase_train)
.def("set_phase_test", &CaffeNet::set_phase_test)
.def("set_device", &CaffeNet::set_device)
.add_property("blobs", &CaffeNet::blobs)
// rename blobs here since the pycaffe.py wrapper will replace it
.add_property("_blobs", &CaffeNet::blobs)
.add_property("layers", &CaffeNet::layers);

boost::python::class_<CaffeBlob, CaffeBlobWrap>(
"CaffeBlob", boost::python::no_init)
"Blob", boost::python::no_init)
.add_property("name", &CaffeBlob::name)
.add_property("num", &CaffeBlob::num)
.add_property("channels", &CaffeBlob::channels)
Expand All @@ -311,10 +332,16 @@ BOOST_PYTHON_MODULE(_caffe) {
.add_property("diff", &CaffeBlobWrap::get_diff);

boost::python::class_<CaffeLayer>(
"CaffeLayer", boost::python::no_init)
"Layer", boost::python::no_init)
.add_property("name", &CaffeLayer::name)
.add_property("blobs", &CaffeLayer::blobs);

boost::python::class_<CaffeSGDSolver, boost::noncopyable>(
"SGDSolver", boost::python::init<string>())
.add_property("net", &CaffeSGDSolver::net)
.def("solve", &CaffeSGDSolver::Solve)
.def("solve", &CaffeSGDSolver::SolveResume);

boost::python::class_<vector<CaffeBlob> >("BlobVec")
.def(vector_indexing_suite<vector<CaffeBlob>, true>());

Expand Down
41 changes: 23 additions & 18 deletions python/caffe/pycaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,31 @@
interface.
"""

from ._caffe import CaffeNet
from ._caffe import Net, SGDSolver
from collections import OrderedDict

class Net(CaffeNet):
# we directly update methods from Net here (rather than using composition or
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
# automatically have the improved interface

@property
def _Net_blobs(self):
"""
An OrderedDict (bottom to top, i.e., input to output) of network
blobs indexed by name
"""
return OrderedDict([(bl.name, bl) for bl in self._blobs])

Net.blobs = _Net_blobs

@property
def _Net_params(self):
"""
The direct Python interface to caffe, exposing Forward and Backward
passes, data, gradients, and layer parameters
An OrderedDict (bottom to top, i.e., input to output) of network
parameters indexed by name; each is a list of multiple blobs (e.g.,
weights and biases)
"""
def __init__(self, param_file, pretrained_param_file):
super(Net, self).__init__(param_file, pretrained_param_file)
self._blobs = OrderedDict([(bl.name, bl)
for bl in super(Net, self).blobs])
self.params = OrderedDict([(lr.name, lr.blobs)
for lr in super(Net, self).layers
if len(lr.blobs) > 0])
return OrderedDict([(lr.name, lr.blobs) for lr in self.layers
if len(lr.blobs) > 0])

@property
def blobs(self):
"""
An OrderedDict (bottom to top, i.e., input to output) of network
blobs indexed by name
"""
return self._blobs
Net.params = _Net_params
16 changes: 15 additions & 1 deletion src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,21 @@ namespace caffe {

template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param)
: param_(param), net_(), test_net_() {
: net_(), test_net_() {
Init(param);
}

// Construct a solver from a prototxt file: parse the text-format
// SolverParameter from param_file, then delegate to Init() exactly as the
// SolverParameter-based constructor does.
// NOTE(review): ReadProtoFromTextFile's failure behavior isn't visible
// here — confirm it checks/aborts on unreadable or malformed files.
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file)
    : net_(), test_net_() {
  SolverParameter solver_param;
  ReadProtoFromTextFile(param_file, &solver_param);
  Init(solver_param);
}

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
param_ = param;
// Scaffolding code
LOG(INFO) << "Creating training net.";
net_.reset(new Net<Dtype>(param_.train_net()));
Expand Down