Skip to content

Commit

Permalink
fix: fix async algorithm aborting (#78)
Browse files Browse the repository at this point in the history
* tmp

* .

* update

* format

* remove unused
  • Loading branch information
wangraying authored Aug 30, 2021
1 parent dbf56fb commit 6eaeb80
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
false,
&mut |c, t| {

if c.check_abort() {
tracing::debug!("process {} exits due to previous abortion", c.rank);
return
}

let start_time = std::time::Instant::now();
tracing::debug!("async model average start");

Expand Down Expand Up @@ -103,22 +108,32 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {

let comm_ready_event = CUDA_EVENT_POOL.take().event;

unsafe {
let ret = unsafe {
cpp::cpp!([
comm_ready_event as "cudaEvent_t",
comm_stream as "cudaStream_t"]
comm_stream as "cudaStream_t"] -> bool as "bool"
{
CUDACHECK(cudaEventRecord(comm_ready_event, comm_stream));
CUDACHECK(cudaEventSynchronize(comm_ready_event));
});
cudaError_t err = cudaEventSynchronize(comm_ready_event);
if (err != cudaSuccess) {
printf("Warning: Cuda error %s:%d '%s'\n", __FILE__,__LINE__,cudaGetErrorString(err));
return false;
}
return true;
})
};

if !ret {
tracing::debug!("process {} early stopped due to communication failure", c.rank);
c.set_abort();
return
}

if c.check_abort() {
tracing::debug!("async model average on process {} early stopped due to communicator abortion", c.rank);
tracing::debug!("process {} early stopped due to communicator abortion", c.rank);
return
}


// do we need to wait for the default stream?
unsafe {
cpp::cpp!([
Expand Down
6 changes: 5 additions & 1 deletion bagua-core/bagua-core-internal/src/communicators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ impl BaguaCommunicatorInner {
pub fn abort(&self) {
let communicator_ptr = self.comm_ptr;

self.aborted.store(true, Ordering::Relaxed);
self.set_abort();

unsafe {
cpp::cpp!([communicator_ptr as "Al::NCCLCommunicator*"]
Expand All @@ -466,6 +466,10 @@ impl BaguaCommunicatorInner {
}
}

/// Marks this communicator as aborted by setting the `aborted` flag to `true`.
///
/// This only flips the flag; it performs no other work. Callers that later
/// poll `check_abort` will observe the flag and can bail out early.
pub fn set_abort(&self) {
// Relaxed ordering: the store is atomic, but it establishes no ordering
// with respect to any other memory accesses.
self.aborted.store(true, Ordering::Relaxed);
}

/// Returns `true` if the `aborted` flag has been set on this communicator.
///
/// The flag is read atomically and is not modified by this call.
pub fn check_abort(&self) -> bool {
// Relaxed ordering: only the flag value itself is read; no synchronization
// with other memory accesses is implied.
self.aborted.load(Ordering::Relaxed)
}
Expand Down

0 comments on commit 6eaeb80

Please sign in to comment.