Skip to content

Commit

Permalink
feat: add (scatter, gather, scatter_reduce) and all inplace version communication primitives (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuhatry authored Jul 15, 2021
1 parent 7f7773d commit 7c62156
Show file tree
Hide file tree
Showing 5 changed files with 524 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,21 @@ impl CommOpTrait for CentralizedFullPrecisionSynchronous {
true,
&mut |c, t| {
tracing::debug!("internode communication started");
let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id]
.try_pull(t.raw.num_elem_allocated * t.raw.dtype.bytes())
.expect("cannot allocate cuda memory");
let mut temp_tensor = BaguaTensorRaw {
ptr: temp_buf.ptr,
num_elem_allocated: t.raw.num_elem_allocated,
dtype: t.raw.dtype.clone(),
num_elem: t.raw.num_elem,
device_id: t.raw.device_id,
pool_allocations: vec![Arc::new(temp_buf)],
};
if self.scattergather {
tracing::debug!("start alltoall");
c.alltoall(&t.raw, &mut temp_tensor);
c.alltoall_inplace(&mut t.raw);
tracing::debug!("start reduce_sum");
if self.average {
temp_tensor.reduce_mean_inplace(c.nranks, c.rank, c.stream_ptr);
t.raw.reduce_mean_inplace(c.nranks, c.rank, c.stream_ptr);
} else {
temp_tensor.reduce_sum_inplace(c.nranks, c.rank, c.stream_ptr);
t.raw.reduce_sum_inplace(c.nranks, c.rank, c.stream_ptr);
}
tracing::debug!("start allgather");
c.allgather(&temp_tensor, &mut t.raw);
c.allgather_inplace(&mut t.raw);
tracing::debug!("internode communication done")
} else {
tracing::debug!("start allreduce");
c.allreduce(&mut t.raw, BaguaReductionOp::SUM);
c.allreduce_inplace(&mut t.raw, BaguaReductionOp::SUM);
tracing::debug!("internode communication done");
if self.average {
t.raw.divide_inplace(stream_ptr, c.nranks as f32);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,17 @@ impl CommOpTrait for CentralizedLowPrecisionSynchronous {
true,
&mut |c, t| {
tracing::debug!("start compress");
let compressed_tensor = t
let mut compressed_tensor = t
.raw
.compress(&self.compression_method, c.nranks, c.stream_ptr, -1)
.expect("cannot compress tensor");
let temp_buf = CUDA_DEVICE_MEMORY_POOL[t.raw.device_id]
.try_pull(
compressed_tensor.num_elements_allocated()
* compressed_tensor.dtype().bytes(),
)
.expect("cannot allocate cuda memory");
let mut temp_tensor = BaguaTensorRaw {
ptr: temp_buf.ptr,
num_elem_allocated: compressed_tensor.num_elements_allocated(),
dtype: compressed_tensor.dtype().clone(),
num_elem: compressed_tensor.num_elements(),
device_id: compressed_tensor.device_id(),
pool_allocations: vec![Arc::new(temp_buf)],
};
tracing::debug!("start alltoall");
c.alltoall(compressed_tensor.as_ref(), &mut temp_tensor);
c.alltoall_inplace(compressed_tensor.as_mut());
tracing::debug!("start decompress");
t.raw.decompress_from(
&self.compression_method,
c.nranks,
&temp_tensor,
compressed_tensor.as_ref(),
c.stream_ptr,
);
tracing::debug!("start reduce_sum");
Expand All @@ -63,7 +49,7 @@ impl CommOpTrait for CentralizedLowPrecisionSynchronous {
t.raw.reduce_sum_inplace(c.nranks, c.rank, c.stream_ptr);
}
tracing::debug!("start compress");
let compressed_tensor = t
let mut compressed_tensor = t
.raw
.compress(
&self.compression_method,
Expand All @@ -73,12 +59,12 @@ impl CommOpTrait for CentralizedLowPrecisionSynchronous {
)
.expect("cannot compress tensor");
tracing::debug!("start allgather");
c.allgather(compressed_tensor.as_ref(), &mut temp_tensor);
c.allgather_inplace(compressed_tensor.as_mut());
tracing::debug!("start decompress");
t.raw.decompress_from(
&self.compression_method,
c.nranks,
&temp_tensor,
compressed_tensor.as_ref(),
c.stream_ptr,
);
tracing::debug!("internode communication done");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl CommOpTrait for DecentralizedFullPrecisionSynchronous {
if step % comm_interval == 0 {
peer_tensor.clone_from(&t.raw, c.stream_ptr);
let _guard = NCCLGroupGuard::new();
c.allreduce(&mut peer_tensor, BaguaReductionOp::SUM);
c.allreduce_inplace(&mut peer_tensor, BaguaReductionOp::SUM);
peer_tensor.divide_inplace(stream_ptr, c.nranks as f32);
}
}
Expand Down
Loading

0 comments on commit 7c62156

Please sign in to comment.