Skip to content

Commit

Permalink
feat: initial support for python op (#2)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: `set_xxx_op`s are renamed to `append_xxx_op` now.
  • Loading branch information
NOBLES5E authored Jun 17, 2021
1 parent bf64338 commit e873538
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 52 deletions.
3 changes: 3 additions & 0 deletions bagua-core/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ push.sh
__pycache__/
*.egg-info/
/dist/
/.eggs/
/build/
.data/
7 changes: 4 additions & 3 deletions bagua-core/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion bagua-core/bagua-core-c/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "bagua-core-c"
version = "0.1.0"
version = "0.1.2"
edition = "2018"

[lib]
Expand Down
5 changes: 4 additions & 1 deletion bagua-core/bagua-core-internal/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "bagua-core-internal"
version = "0.1.0"
version = "0.1.2"
authors = ["Xiangru Lian <admin@mail.xrlian.com>"]
edition = "2018"
publish = ["private"]
Expand All @@ -27,6 +27,9 @@ scheduled-thread-pool = "0.2"
serde_json = "1.0"
ureq = "2.1"

[dependencies.pyo3]
version = "0.13.2"

[build-dependencies]
shadow-rs = "0.5"
cpp_build = "0.5"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,13 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
);

if step % comm_interval == 0 {
// TODO: move this to the `.then()` Python API instead of hard-coding it in the op
let post_backward_comm_op = BaguaScheduledCommOp {
bucket: bucket.clone(),
op: Arc::new(DecentralizedFullPrecisionSynchronousPostStep {
ops: vec![Arc::new(DecentralizedFullPrecisionSynchronousPostStep {
communicator: self.communicator.clone(),
result_weight: peer_tensor,
}),
})],
event_channel: Default::default(),
};

Expand Down
1 change: 1 addition & 0 deletions bagua-core/bagua-core-internal/src/comm_ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod centralized_full_precision_synchronous;
pub mod centralized_low_precision_synchronous;
pub mod decentralized_full_precision_synchronous;
pub mod python_ffi_op;

use crate::datatypes::BaguaBucket;
use crate::BaguaCommOpChannels;
Expand Down
27 changes: 27 additions & 0 deletions bagua-core/bagua-core-internal/src/comm_ops/python_ffi_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use crate::comm_ops::CommOpTrait;
use crate::communicators::BaguaCommunicator;
use crate::datatypes::{BaguaBucket, BaguaTensorRaw};
use crate::resource_pool::CUDA_DEVICE_MEMORY_POOL;
use crate::BaguaCommOpChannels;
use pyo3::Python;
use std::sync::Arc;

/// Communication op whose work is delegated to a Python callable.
#[derive(Debug)]
pub struct PythonFFIOp {
/// Python object that is called with the bucket's name when the op executes.
pub py_callable: pyo3::Py<pyo3::PyAny>,
}

impl CommOpTrait for PythonFFIOp {
fn execute_background_communication(
&self,
bucket: Arc<BaguaBucket>,
_comm_op_channels: &BaguaCommOpChannels,
) {
Python::with_gil(|python| {
let result = self.py_callable.call1(python, (bucket.name.as_str(),));
if let Err(e) = result {
tracing::error!("python ffi op error: {:?}", e);
}
});
}
}
21 changes: 15 additions & 6 deletions bagua-core/bagua-core-internal/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::comm_ops::centralized_low_precision_synchronous::CentralizedLowPrecis
use crate::comm_ops::decentralized_full_precision_synchronous::{
DecentralizedFullPrecisionSynchronous, PeerSelectionMode,
};
use crate::comm_ops::python_ffi_op::PythonFFIOp;
use crate::comm_ops::CommOpTrait;
use crate::communicators::{BaguaCommunicator, BaguaSingleCommunicator};
use crate::resource_pool::{CudaMemory, CUDA_DEVICE_MEMORY_POOL};
Expand Down Expand Up @@ -586,7 +587,7 @@ pub struct BaguaBucketInner {
pub tensors: Vec<BaguaTensor>,
pub dtype: BaguaTensorDtype,
pub inplace: bool,
pub comm_op: Option<Arc<dyn CommOpTrait + Sync + Send>>,
pub comm_ops: Vec<Arc<dyn CommOpTrait + Sync + Send>>,
pub align_bytes: usize,
}

Expand Down Expand Up @@ -734,12 +735,14 @@ impl<'b> Drop for BaguaCommunicationTensor<'b> {
/// A named bucket of tensors. Cloning copies `id` and `name` and shares
/// `inner` through the `Arc`.
#[derive(Debug, Clone)]
pub struct BaguaBucket {
/// Identifier assigned in `new` from `lazy_id::Id::lazy().get()`.
pub id: u64,
/// Name supplied by the caller of `new`.
pub name: String,
/// Mutex-guarded bucket state (tensors, dtype, comm ops, ...) shared between clones.
pub inner: Arc<Mutex<BaguaBucketInner>>,
}

impl BaguaBucket {
pub fn new(
tensors: &[&BaguaTensor],
name: &str,
inplace: bool,
align_bytes: usize,
) -> Result<Self, BaguaCoreError> {
Expand Down Expand Up @@ -812,10 +815,11 @@ impl BaguaBucket {
let id = lazy_id::Id::lazy().get();
Ok(Self {
id,
name: name.to_owned(),
inner: Arc::new(Mutex::new(BaguaBucketInner {
inplace,
tensors: tensors.iter().map(|x| (**x).clone()).collect(),
comm_op: None,
comm_ops: vec![],
dtype: tensors.first().unwrap().inner.read().raw.dtype.clone(),
align_bytes,
})),
Expand All @@ -826,7 +830,7 @@ impl BaguaBucket {
self.inner.lock().tensors.clone()
}

pub fn set_decentralized_synchronous_op(
pub fn append_decentralized_synchronous_op(
&mut self,
communicator_internode: Option<&BaguaSingleCommunicator>,
communicator_intranode: Option<&BaguaSingleCommunicator>,
Expand Down Expand Up @@ -857,12 +861,17 @@ impl BaguaBucket {
}
},
};
self.inner.lock().comm_op = Some(comm_op);
self.inner.lock().comm_ops.push(comm_op);
}

/// Appends a Python-backed communication op to this bucket.
///
/// `op` is wrapped in a [`PythonFFIOp`] and pushed onto the end of the
/// bucket's `comm_ops` list; existing ops are left in place.
pub fn append_python_op(&mut self, op: pyo3::Py<pyo3::PyAny>) {
    let python_op = PythonFFIOp { py_callable: op };
    self.inner
        .lock()
        .comm_ops
        .push(Arc::new(python_op) as Arc<dyn CommOpTrait + Send + Sync>);
}

/// This function uses `communicator_internode` to communicate.
/// If `hierarchical` is true, it performs hierarchical communication, which requires an intranode communicator on each node and an internode communicator on the leader GPU. The leader GPU is the GPU whose `communicator_intranode` rank is 0.
pub fn set_centralized_synchronous_op(
pub fn append_centralized_synchronous_op(
&mut self,
communicator_internode: Option<&BaguaSingleCommunicator>,
communicator_intranode: Option<&BaguaSingleCommunicator>,
Expand Down Expand Up @@ -893,7 +902,7 @@ impl BaguaBucket {
}
},
};
self.inner.lock().comm_op = Some(comm_op);
self.inner.lock().comm_ops.push(comm_op);
}

pub fn ready_for_comm(&self) -> bool {
Expand Down
36 changes: 21 additions & 15 deletions bagua-core/bagua-core-internal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub enum BaguaCoreError {
#[derive(Debug)]
pub struct BaguaScheduledCommOp {
pub bucket: Arc<BaguaBucket>,
pub op: Arc<dyn CommOpTrait + Send + Sync>,
pub ops: Vec<Arc<dyn CommOpTrait + Send + Sync>>,
pub event_channel: BaguaEventChannel,
}

Expand Down Expand Up @@ -125,14 +125,17 @@ pub struct BaguaCommBackend {
impl BaguaCommBackend {
pub fn schedule_comm(&self, bucket: Arc<BaguaBucket>) -> Result<(), BaguaCoreError> {
let event_channel = BaguaEventChannel::default();
self.channels.schedule_channel_sender.send(BaguaScheduledCommOp {
op: {
let guard = bucket.inner.lock();
guard.comm_op.clone().expect("bucket must have communication operator set before scheduled for communication")
},
bucket,
event_channel: event_channel.clone(),
}).map_err(|e| BaguaCoreError::InternalChannelError(format!("{:?}", e)))?;
self.channels
.schedule_channel_sender
.send(BaguaScheduledCommOp {
ops: {
let guard = bucket.inner.lock();
guard.comm_ops.clone()
},
bucket,
event_channel: event_channel.clone(),
})
.map_err(|e| BaguaCoreError::InternalChannelError(format!("{:?}", e)))?;
Ok(self
.channels
.not_waited_events_sender
Expand Down Expand Up @@ -187,9 +190,12 @@ impl BaguaCommBackend {
"worker received scheduled communication operation {:?}",
comm_op
);
comm_op
.op
.execute_background_communication(comm_op.bucket.clone(), &channels_clone);
for op in &comm_op.ops {
op.execute_background_communication(
comm_op.bucket.clone(),
&channels_clone,
);
}
tracing::debug!("comm op executed: {:?}", comm_op);
comm_op.event_channel.finish();
tracing::debug!("comm op marked finished: {:?}", comm_op);
Expand Down Expand Up @@ -292,9 +298,9 @@ impl BaguaCommBackend {
match comm_op {
Ok(comm_op) => {
tracing::debug!("received post step communication operation {:?}", comm_op);
comm_op
.op
.execute_background_communication(comm_op.bucket.clone(), &self.channels);
for op in &comm_op.ops {
op.execute_background_communication(comm_op.bucket.clone(), &self.channels);
}
tracing::debug!("comm op executed: {:?}", comm_op);
comm_op.event_channel.finish();
tracing::debug!("comm op marked finished: {:?}", comm_op);
Expand Down
2 changes: 1 addition & 1 deletion bagua-core/bagua-core-py/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "bagua-core-py"
version = "0.1.0"
version = "0.1.2"
authors = ["Xiangru Lian <admin@mail.xrlian.com>"]
edition = "2018"
publish = ["private"]
Expand Down
Loading

0 comments on commit e873538

Please sign in to comment.