Skip to content

Commit

Permalink
fix: replace mutex with atomic bool for async op and add Aluminum submodule update (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangraying authored Aug 19, 2021
1 parent 006f8fe commit 64d5cf0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
temp_tensor.clone_from(&t.raw, torch_stream as u64);

let src_ready_event = CUDA_EVENT_POOL.take().event;
let dst_ready_event = CUDA_EVENT_POOL.take().event;

unsafe {
cpp::cpp!([
Expand All @@ -100,17 +99,26 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous {
PeerSelectionMode::ShiftOne => {
unimplemented!()
}
};
};

let comm_ready_event = CUDA_EVENT_POOL.take().event;

unsafe {
cpp::cpp!([comm_stream as "cudaStream_t"] { CUDACHECK(cudaStreamSynchronize(comm_stream)); });
cpp::cpp!([
comm_ready_event as "cudaEvent_t",
comm_stream as "cudaStream_t"]
{
CUDACHECK(cudaEventRecord(comm_ready_event, comm_stream));
CUDACHECK(cudaEventSynchronize(comm_ready_event));
});
}

if c.check_abort() {
tracing::debug!("async model average on process {} early stopped due to communicator abortion", c.rank);
return
}


// do we need to wait default stream?
unsafe {
cpp::cpp!([
Expand Down
11 changes: 6 additions & 5 deletions bagua-core/bagua-core-internal/src/communicators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::datatypes::{
};
use crate::BaguaCoreError;
use itertools::Itertools;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

#[derive(Clone, Debug)]
Expand All @@ -13,7 +13,7 @@ pub struct BaguaCommunicatorInner {
pub rank: usize,
pub nranks: usize,
pub device_id: usize,
pub aborted: Arc<Mutex<bool>>,
pub aborted: Arc<AtomicBool>,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -54,7 +54,7 @@ impl BaguaSingleCommunicator {
rank,
nranks,
device_id,
aborted: Arc::new(Mutex::new(false)),
aborted: Arc::new(AtomicBool::new(false)),
}),
}
}
Expand Down Expand Up @@ -456,7 +456,8 @@ impl BaguaCommunicatorInner {
pub fn abort(&self) {
let communicator_ptr = self.comm_ptr;

*self.aborted.lock() = true;
self.aborted.store(true, Ordering::Relaxed);

unsafe {
cpp::cpp!([communicator_ptr as "Al::NCCLCommunicator*"]
{
Expand All @@ -466,7 +467,7 @@ impl BaguaCommunicatorInner {
}

pub fn check_abort(&self) -> bool {
*self.aborted.lock()
self.aborted.load(Ordering::Relaxed)
}

pub fn broadcast(&self, tensor: &mut dyn RawBaguaTensor, root_rank: i32) {
Expand Down
2 changes: 1 addition & 1 deletion bagua-core/bagua-core-internal/third_party/Aluminum
Submodule Aluminum updated from edb7e1 to dc4ff7

0 comments on commit 64d5cf0

Please sign in to comment.