diff --git a/bagua-core/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs b/bagua-core/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs index c6b1abdc4..6250f004d 100644 --- a/bagua-core/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs +++ b/bagua-core/bagua-core-internal/src/comm_ops/decentralized_full_precision_asynchronous.rs @@ -77,7 +77,6 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { temp_tensor.clone_from(&t.raw, torch_stream as u64); let src_ready_event = CUDA_EVENT_POOL.take().event; - let dst_ready_event = CUDA_EVENT_POOL.take().event; unsafe { cpp::cpp!([ @@ -100,10 +99,18 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { PeerSelectionMode::ShiftOne => { unimplemented!() } - }; + }; + + let comm_ready_event = CUDA_EVENT_POOL.take().event; unsafe { - cpp::cpp!([comm_stream as "cudaStream_t"] { CUDACHECK(cudaStreamSynchronize(comm_stream)); }); + cpp::cpp!([ + comm_ready_event as "cudaEvent_t", + comm_stream as "cudaStream_t"] + { + CUDACHECK(cudaEventRecord(comm_ready_event, comm_stream)); + CUDACHECK(cudaEventSynchronize(comm_ready_event)); + }); } if c.check_abort() { @@ -111,6 +118,7 @@ impl CommOpTrait for DecentralizedFullPrecisionAsynchronous { return } + // do we need to wait default stream? unsafe { cpp::cpp!([ diff --git a/bagua-core/bagua-core-internal/src/communicators/mod.rs b/bagua-core/bagua-core-internal/src/communicators/mod.rs index 47f711add..e1d6bc4b7 100644 --- a/bagua-core/bagua-core-internal/src/communicators/mod.rs +++ b/bagua-core/bagua-core-internal/src/communicators/mod.rs @@ -3,7 +3,7 @@ use crate::datatypes::{ }; use crate::BaguaCoreError; use itertools::Itertools; -use parking_lot::Mutex; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; #[derive(Clone, Debug)] @@ -13,7 +13,7 @@ pub struct BaguaCommunicatorInner { pub rank: usize, pub nranks: usize, pub device_id: usize, - pub aborted: Arc>, + pub aborted: Arc, } #[derive(Clone, Debug)] @@ -54,7 +54,7 @@ impl BaguaSingleCommunicator { rank, nranks, device_id, - aborted: Arc::new(Mutex::new(false)), + aborted: Arc::new(AtomicBool::new(false)), }), } } @@ -456,7 +456,8 @@ impl BaguaCommunicatorInner { pub fn abort(&self) { let communicator_ptr = self.comm_ptr; - *self.aborted.lock() = true; + self.aborted.store(true, Ordering::Relaxed); + unsafe { cpp::cpp!([communicator_ptr as "Al::NCCLCommunicator*"] { @@ -466,7 +467,7 @@ impl BaguaCommunicatorInner { } pub fn check_abort(&self) -> bool { - *self.aborted.lock() + self.aborted.load(Ordering::Relaxed) } pub fn broadcast(&self, tensor: &mut dyn RawBaguaTensor, root_rank: i32) { diff --git a/bagua-core/bagua-core-internal/third_party/Aluminum b/bagua-core/bagua-core-internal/third_party/Aluminum index edb7e1529..dc4ff784d 160000 --- a/bagua-core/bagua-core-internal/third_party/Aluminum +++ b/bagua-core/bagua-core-internal/third_party/Aluminum @@ -1 +1 @@ -Subproject commit edb7e1529498aa4858b362fc052c78afdcd85a68 +Subproject commit dc4ff784dc4201ee1358b1bace2d6b26d70ef9c0