Merge pull request #3588 from junshi15/P2psyncPrepare
Refine P2PSync
longjon committed Mar 6, 2016
2 parents 74cc497 + 0152891 commit 54162f8
Showing 4 changed files with 19 additions and 10 deletions.
include/caffe/parallel.hpp (5 changes: 4 additions & 1 deletion)
@@ -93,7 +93,10 @@ class P2PSync : public GPUParams<Dtype>, public Solver<Dtype>::Callback,
     return solver_;
   }
 
-  void run(const vector<int>& gpus);
+  void Run(const vector<int>& gpus);
+  void Prepare(const vector<int>& gpus,
+               vector<shared_ptr<P2PSync<Dtype> > >* syncs);
+  inline const int initial_iter() const { return initial_iter_; }
 
  protected:
   void on_start();
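In effect, the single run() entry point is split into a public two-step interface. A hedged recap of the declarations added above, with explanatory comments that are mine rather than part of the commit:

  void Run(const vector<int>& gpus);      // build the per-GPU sync tree, then run the optimization (replaces run())
  void Prepare(const vector<int>& gpus,
               vector<shared_ptr<P2PSync<Dtype> > >* syncs);   // build the tree only; fills syncs[1..] with the child P2PSync instances
  inline const int initial_iter() const { return initial_iter_; }   // read-only access to the stored initial_iter_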
src/caffe/parallel.cpp (20 changes: 13 additions & 7 deletions)
@@ -380,7 +380,8 @@ void P2PSync<Dtype>::on_gradients_ready() {
 }
 
 template<typename Dtype>
-void P2PSync<Dtype>::run(const vector<int>& gpus) {
+void P2PSync<Dtype>::Prepare(const vector<int>& gpus,
+            vector<shared_ptr<P2PSync<Dtype> > >* syncs) {
   // Pair devices for map-reduce synchronization
   vector<DevicePair> pairs;
   DevicePair::compute(gpus, &pairs);
@@ -391,15 +392,14 @@ void P2PSync<Dtype>::run(const vector<int>& gpus) {
   LOG(INFO)<< "GPUs pairs " << s.str();
 
   SolverParameter param(solver_->param());
-  vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
 
   // Build the GPU tree by finding the parent for each solver
   for (int attempts = 0; attempts < pairs.size(); ++attempts) {
     for (int i = 1; i < pairs.size(); ++i) {
-      if (!syncs[i].get()) {
+      if (!syncs->at(i).get()) {
         P2PSync<Dtype>* parent = NULL;
-        for (int j = 0; j < syncs.size(); ++j) {
-          P2PSync<Dtype>* sync = j == 0 ? this : syncs[j].get();
+        for (int j = 0; j < syncs->size(); ++j) {
+          P2PSync<Dtype>* sync = j == 0 ? this : syncs->at(j).get();
           if (sync) {
             const SolverParameter& p = sync->solver()->param();
             if (p.device_id() == pairs[i].parent()) {
@@ -409,12 +409,18 @@
         }
         if (parent) {
           param.set_device_id(pairs[i].device());
-          syncs[i].reset(new P2PSync<Dtype>(solver_, parent, param));
-          parent->children_.push_back((P2PSync<Dtype>*) syncs[i].get());
+          syncs->at(i).reset(new P2PSync<Dtype>(solver_, parent, param));
+          parent->children_.push_back((P2PSync<Dtype>*) syncs->at(i).get());
         }
       }
     }
   }
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::Run(const vector<int>& gpus) {
+  vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
+  Prepare(gpus, &syncs);
 
   LOG(INFO)<< "Starting Optimization";
 
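For illustration only, not part of this commit: with Prepare() public, a caller can build the GPU tree itself and keep handles on the per-device workers. A minimal sketch, assuming the usual Caffe headers, a shared_ptr<Solver<float> > named solver, and a device list gpus as in tools/caffe.cpp below (vector/shared_ptr left unqualified as in the Caffe sources):

  caffe::P2PSync<float> sync(solver, NULL, solver->param());       // root instance, no parent
  vector<shared_ptr<caffe::P2PSync<float> > > syncs(gpus.size());  // slot 0 stays empty; the root is the sync object itself
  sync.Prepare(gpus, &syncs);   // wires parent/child links for syncs[1..] following the DevicePair layout
  // Run(gpus) performs exactly these two steps internally and then starts the optimization.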
src/caffe/test/test_gradient_based_solver.cpp (2 changes: 1 addition & 1 deletion)
@@ -204,7 +204,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       Caffe::set_solver_count(gpus.size());
       this->sync_.reset(new P2PSync<Dtype>(
           this->solver_, NULL, this->solver_->param()));
-      this->sync_->run(gpus);
+      this->sync_->Run(gpus);
       Caffe::set_solver_count(1);
     }
     if (snapshot) {
tools/caffe.cpp (2 changes: 1 addition & 1 deletion)
@@ -214,7 +214,7 @@ int train() {

   if (gpus.size() > 1) {
     caffe::P2PSync<float> sync(solver, NULL, solver->param());
-    sync.run(gpus);
+    sync.Run(gpus);
   } else {
     LOG(INFO) << "Starting Optimization";
     solver->Solve();