diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp index 85fc2b55984..6c496c884e3 100644 --- a/include/caffe/parallel.hpp +++ b/include/caffe/parallel.hpp @@ -93,7 +93,10 @@ class P2PSync : public GPUParams, public Solver::Callback, return solver_; } - void run(const vector& gpus); + void Run(const vector& gpus); + void Prepare(const vector& gpus, + vector > >* syncs); + inline const int initial_iter() const { return initial_iter_; } protected: void on_start(); diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp index 62f5d738593..5bc41c6a6e5 100644 --- a/src/caffe/parallel.cpp +++ b/src/caffe/parallel.cpp @@ -380,7 +380,8 @@ void P2PSync::on_gradients_ready() { } template -void P2PSync::run(const vector& gpus) { +void P2PSync::Prepare(const vector& gpus, + vector > >* syncs) { // Pair devices for map-reduce synchronization vector pairs; DevicePair::compute(gpus, &pairs); @@ -391,15 +392,14 @@ void P2PSync::run(const vector& gpus) { LOG(INFO)<< "GPUs pairs " << s.str(); SolverParameter param(solver_->param()); - vector > > syncs(gpus.size()); // Build the GPU tree by finding the parent for each solver for (int attempts = 0; attempts < pairs.size(); ++attempts) { for (int i = 1; i < pairs.size(); ++i) { - if (!syncs[i].get()) { + if (!syncs->at(i).get()) { P2PSync* parent = NULL; - for (int j = 0; j < syncs.size(); ++j) { - P2PSync* sync = j == 0 ? this : syncs[j].get(); + for (int j = 0; j < syncs->size(); ++j) { + P2PSync* sync = j == 0 ? this : syncs->at(j).get(); if (sync) { const SolverParameter& p = sync->solver()->param(); if (p.device_id() == pairs[i].parent()) { @@ -409,12 +409,18 @@ void P2PSync::run(const vector& gpus) { } if (parent) { param.set_device_id(pairs[i].device()); - syncs[i].reset(new P2PSync(solver_, parent, param)); - parent->children_.push_back((P2PSync*) syncs[i].get()); + syncs->at(i).reset(new P2PSync(solver_, parent, param)); + parent->children_.push_back((P2PSync*) syncs->at(i).get()); } } } } +} + +template +void P2PSync::Run(const vector& gpus) { + vector > > syncs(gpus.size()); + Prepare(gpus, &syncs); LOG(INFO)<< "Starting Optimization"; diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 09ec3a7e918..975a8f0f88a 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -204,7 +204,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { Caffe::set_solver_count(gpus.size()); this->sync_.reset(new P2PSync( this->solver_, NULL, this->solver_->param())); - this->sync_->run(gpus); + this->sync_->Run(gpus); Caffe::set_solver_count(1); } if (snapshot) { diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 95b2f82c4be..5d9331f0c22 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -214,7 +214,7 @@ int train() { if (gpus.size() > 1) { caffe::P2PSync sync(solver, NULL, solver->param()); - sync.run(gpus); + sync.Run(gpus); } else { LOG(INFO) << "Starting Optimization"; solver->Solve();