feat: add compatibility for NCCL version under 2.10 #449

Merged 4 commits on Dec 24, 2021
Changes from 3 commits
129 changes: 99 additions & 30 deletions rust/bagua-core/bagua-core-internal/src/communicators/mod.rs
@@ -14,6 +14,7 @@ pub struct BaguaCommunicatorInner {
pub nranks: usize,
pub device_id: usize,
pub aborted: Arc<AtomicBool>,
pub degraded: bool,
}
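
Note on the new field: NCCL gained a native average reduction (ncclAvg) in 2.10, so on older runtimes the communicator has to fall back to SUM plus a division, which is what the rest of this diff wires up. A hypothetical doc comment for the field, summarizing that role (not part of the PR):

```rust
/// Set when the runtime NCCL predates 2.10 and therefore lacks ncclAvg;
/// AVG reductions on this communicator are emulated as SUM followed by
/// an element-wise division by `nranks`.
pub degraded: bool,
```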

#[derive(Clone, Debug)]
@@ -47,6 +48,16 @@ impl BaguaSingleCommunicator {
};
tracing::debug!("al communicator initialized at {}", al_comm_ptr,);

let mut version: i32 = 0;
let version_ptr: *mut i32 = &mut version;
unsafe {
cpp::cpp!([version_ptr as "int *"]
{ NCCLCHECK(ncclGetVersion(version_ptr)); });
}

tracing::debug!("runtime nccl version: {}", version);
let degraded = version < 21000;

Self {
inner: Arc::new(BaguaCommunicatorInner {
stream_ptr,
@@ -55,6 +66,7 @@
nranks,
device_id,
aborted: Arc::new(AtomicBool::new(false)),
degraded: degraded,
}),
}
}
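
A note on the threshold: `ncclGetVersion` writes an encoded integer; in current releases the encoding is `major*10000 + minor*100 + patch`, so 2.10.0 maps to 21000 and `version < 21000` selects every release before the AVG operator was introduced. A minimal standalone sketch of the same detection, assuming an FFI binding named `ncclGetVersion` is available (the PR itself goes through the `cpp!` macro instead):

```rust
use std::os::raw::c_int;

extern "C" {
    // Assumed FFI declaration; the PR calls this through cpp::cpp! instead.
    fn ncclGetVersion(version: *mut c_int) -> c_int;
}

/// Returns true when the runtime NCCL predates 2.10 and AVG must be emulated.
fn nccl_needs_avg_fallback() -> bool {
    let mut version: c_int = 0;
    // ncclGetVersion fills in the encoded version (e.g. 21000 for 2.10.0)
    // and returns ncclSuccess (0) on success.
    let rc = unsafe { ncclGetVersion(&mut version) };
    assert_eq!(rc, 0, "ncclGetVersion failed");
    version < 21000
}
```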
@@ -518,23 +530,32 @@ impl BaguaCommunicatorInner {
let count = send_tensor.num_elements_allocated();
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();

let degraded = match op {
BaguaReductionOp::AVG => self.degraded,
_ => false,
};
let reduction_op = if degraded { BaguaReductionOp::SUM } else { op };

unsafe {
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", dst as "int", count as "size_t", op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", dst as "int", count as "size_t", reduction_op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Reduce<Al::NCCLBackend>(static_cast<const float*>(send_ptr), static_cast<float*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), dst, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<const float*>(send_ptr), static_cast<float*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), dst, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Reduce<Al::NCCLBackend>(static_cast<const __half*>(send_ptr), static_cast<__half*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), dst, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<const __half*>(send_ptr), static_cast<__half*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), dst, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Reduce<Al::NCCLBackend>(static_cast<const unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), dst, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<const unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), dst, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Reduce<Al::NCCLBackend>(static_cast<const long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), dst, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<const long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), dst, *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
if degraded {
recv_tensor.divide_inplace(self.stream_ptr, self.nranks as f32);
}
}
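
The same fallback pattern repeats in each collective below: if the requested op is AVG and the communicator is degraded, issue the collective with SUM and divide the result by `nranks` afterwards. A condensed sketch of that control flow, using the types from this file; `launch_collective` is a hypothetical stand-in for the dtype-dispatched `cpp!` blocks, not part of the PR:

```rust
impl BaguaCommunicatorInner {
    /// Sketch of the pattern shared by reduce, reduce_scatter and allreduce.
    fn collective_with_avg_fallback(
        &self,
        send: &dyn RawBaguaTensor,
        recv: &mut dyn RawBaguaTensor,
        op: BaguaReductionOp,
    ) {
        // Only rewrite AVG, and only when the runtime NCCL lacks ncclAvg.
        let degraded = matches!(op, BaguaReductionOp::AVG) && self.degraded;
        let reduction_op = if degraded { BaguaReductionOp::SUM } else { op };

        // Hypothetical helper standing in for the Al::Reduce/Allreduce calls.
        self.launch_collective(send, recv, reduction_op);

        if degraded {
            // SUM divided by nranks is AVG, element-wise, on the same stream.
            recv.divide_inplace(self.stream_ptr, self.nranks as f32);
        }
    }
}
```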

pub fn reduce_inplace(
@@ -548,23 +569,32 @@
let total_num_elem = tensor.num_elements_allocated();
let nccl_tensor_type = tensor.dtype().to_nccl_datatype();

let degraded = match op {
BaguaReductionOp::AVG => self.degraded,
_ => false,
};
let reduction_op = if degraded { BaguaReductionOp::SUM } else { op };

unsafe {
cpp::cpp!([tensor_ptr as "void *", root_rank as "int", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", op as "uint8_t"]
cpp::cpp!([tensor_ptr as "void *", root_rank as "int", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", reduction_op as "uint8_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Reduce<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), root_rank, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(reduction_op), root_rank, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Reduce<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), root_rank, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(reduction_op), root_rank, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Reduce<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), root_rank, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(reduction_op), root_rank, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Reduce<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), root_rank, *communicator_ptr);
Al::Reduce<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(reduction_op), root_rank, *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
if degraded {
tensor.divide_inplace(self.stream_ptr, self.nranks as f32);
}
}

pub fn alltoall(&self, send_tensor: &dyn RawBaguaTensor, recv_tensor: &mut dyn RawBaguaTensor) {
@@ -963,23 +993,33 @@ impl BaguaCommunicatorInner {
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();
assert_eq!(count * self.nranks, send_tensor.num_elements_allocated());
assert_eq!(send_tensor.dtype(), recv_tensor.dtype());

let degraded = match op {
BaguaReductionOp::AVG => self.degraded,
_ => false,
};
let reduction_op = if degraded { BaguaReductionOp::SUM } else { op };

unsafe {
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", reduction_op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const float*>(send_ptr), static_cast<float*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const float*>(send_ptr), static_cast<float*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const __half*>(send_ptr), static_cast<__half*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const __half*>(send_ptr), static_cast<__half*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<const long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
if degraded {
recv_tensor.divide_inplace(self.stream_ptr, self.nranks as f32);
}
}

pub fn reduce_scatter_inplace(&self, tensor: &mut dyn RawBaguaTensor, op: BaguaReductionOp) {
@@ -992,23 +1032,33 @@
let tensor_ptr = tensor.data_ptr();
let count = tensor.num_elements_allocated() / self.nranks;
let nccl_tensor_type = tensor.dtype().to_nccl_datatype();

let degraded = match op {
BaguaReductionOp::AVG => self.degraded,
_ => false,
};
let reduction_op = if degraded { BaguaReductionOp::SUM } else { op };

unsafe {
cpp::cpp!([tensor_ptr as "void *", count as "size_t", op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
cpp::cpp!([tensor_ptr as "void *", count as "size_t", reduction_op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Reduce_scatter<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
if degraded {
tensor.divide_inplace(self.stream_ptr, self.nranks as f32);
}
}

pub fn barrier(&self) {
@@ -1039,23 +1089,32 @@ impl BaguaCommunicatorInner {
let count = send_tensor.num_elements_allocated();
let nccl_tensor_type = send_tensor.dtype().to_nccl_datatype();

let degraded = match op {
BaguaReductionOp::AVG => self.degraded,
_ => false,
};
let reduction_op = if degraded { BaguaReductionOp::SUM } else { op };

unsafe {
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
cpp::cpp!([send_ptr as "void *", recv_ptr as "void *", count as "size_t", reduction_op as "uint8_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Allreduce<Al::NCCLBackend>(static_cast<const float*>(send_ptr), static_cast<float*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<const float*>(send_ptr), static_cast<float*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Allreduce<Al::NCCLBackend>(static_cast<const __half*>(send_ptr), static_cast<__half*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<const __half*>(send_ptr), static_cast<__half*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Allreduce<Al::NCCLBackend>(static_cast<const unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<const unsigned char*>(send_ptr), static_cast<unsigned char*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Allreduce<Al::NCCLBackend>(static_cast<const long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<const long long int*>(send_ptr), static_cast<long long int*>(recv_ptr), count, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
if degraded {
recv_tensor.divide_inplace(self.stream_ptr, self.nranks as f32);
}
}
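
Why the divide restores AVG semantics: for every element, the average across ranks is exactly the sum across ranks divided by `nranks`, so as long as the division runs on the same stream before the result is consumed, callers see the same values a native ncclAvg would produce (up to floating-point rounding). A self-contained host-side check of that identity, purely illustrative with made-up data:

```rust
fn main() {
    // Per-rank contributions for a 4-element tensor on 3 ranks (made-up data).
    let per_rank: [[f32; 4]; 3] = [
        [1.0, 2.0, 3.0, 4.0],
        [5.0, 6.0, 7.0, 8.0],
        [9.0, 10.0, 11.0, 12.0],
    ];
    let nranks = per_rank.len() as f32;

    // Emulated path: SUM "allreduce", then divide in place.
    let mut summed = [0.0f32; 4];
    for rank in &per_rank {
        for (s, x) in summed.iter_mut().zip(rank) {
            *s += x;
        }
    }
    let emulated: Vec<f32> = summed.iter().map(|s| s / nranks).collect();

    // Reference: direct element-wise average.
    let direct: Vec<f32> = (0..4)
        .map(|i| per_rank.iter().map(|r| r[i]).sum::<f32>() / nranks)
        .collect();

    assert_eq!(emulated, direct);
    println!("emulated AVG == direct AVG: {:?}", emulated);
}
```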

pub fn allreduce_inplace(&self, tensor: &mut dyn RawBaguaTensor, op: BaguaReductionOp) {
@@ -1064,22 +1123,32 @@
let total_num_elem = tensor.num_elements_allocated();
let nccl_tensor_type = tensor.dtype().to_nccl_datatype();

let degraded = match op {
BaguaReductionOp::AVG => self.degraded,
_ => false,
};
let reduction_op = if degraded { BaguaReductionOp::SUM } else { op };

unsafe {
cpp::cpp!([tensor_ptr as "void *", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", op as "uint8_t"]
cpp::cpp!([tensor_ptr as "void *", total_num_elem as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t", reduction_op as "uint8_t"]
{
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Allreduce<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Allreduce<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Allreduce<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Allreduce<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(op), *communicator_ptr);
Al::Allreduce<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), total_num_elem, static_cast<Al::ReductionOperator>(reduction_op), *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}

if degraded {
tensor.divide_inplace(self.stream_ptr, self.nranks as f32);
}
}
}