feat(python, core): support all_to_all_single #361

Merged · 36 commits · Nov 9, 2021

Commits
b171e26  support all_to_all_single (Youhe-Jiang, Nov 4, 2021)
f697dde  support all_to_all_single (Youhe-Jiang, Nov 4, 2021)
f7d8f80  feat: support all_to_all_single (Youhe-Jiang, Nov 4, 2021)
0b30efb  support all_to_all_single (Youhe-Jiang, Nov 4, 2021)
ed92d7d  support all_to_all_single (Youhe-Jiang, Nov 4, 2021)
a9a92d0  support all_to_all_single (Youhe-Jiang, Nov 4, 2021)
d18be1b  support all_to_all_single (Youhe-Jiang, Nov 4, 2021)
3ab9035  support all_to_all_single (Youhe-Jiang, Nov 5, 2021)
8d59621  support all_to_all_single (Youhe-Jiang, Nov 5, 2021)
4274b7f  support all_to_all_single (Youhe-Jiang, Nov 5, 2021)
4ef84a5  support all_to_all_single (Youhe-Jiang, Nov 5, 2021)
17d3b1e  support all_to_all_single (Youhe-Jiang, Nov 5, 2021)
85f21e3  support all_to_all_single (Youhe-Jiang, Nov 5, 2021)
64481b7  support all_to_all_single (Youhe-Jiang, Nov 5, 2021)
bf7e4ec  support (Youhe-Jiang, Nov 5, 2021)
b64e353  support (Youhe-Jiang, Nov 5, 2021)
16ddb4d  test (Youhe-Jiang, Nov 5, 2021)
9368dce  support (Youhe-Jiang, Nov 5, 2021)
6459b20  support (Youhe-Jiang, Nov 5, 2021)
2f92884  support (Youhe-Jiang, Nov 5, 2021)
ac7d9c3  support (Youhe-Jiang, Nov 5, 2021)
d981d18  support (Youhe-Jiang, Nov 5, 2021)
19a42b6  support (Youhe-Jiang, Nov 5, 2021)
ec2a289  support (Youhe-Jiang, Nov 5, 2021)
b5cce72  support (Youhe-Jiang, Nov 5, 2021)
d954e25  support (Youhe-Jiang, Nov 5, 2021)
80422d4  support alltoallsingle (Youhe-Jiang, Nov 8, 2021)
df57105  support alltoallsingle (Youhe-Jiang, Nov 8, 2021)
cb1dcfb  float type (Youhe-Jiang, Nov 8, 2021)
cad6c6b  float type (Youhe-Jiang, Nov 8, 2021)
752ab7e  int type (Youhe-Jiang, Nov 8, 2021)
25f9759  int type (Youhe-Jiang, Nov 8, 2021)
1f93167  int type (Youhe-Jiang, Nov 8, 2021)
4068226  int type (Youhe-Jiang, Nov 8, 2021)
b592958  int type (Youhe-Jiang, Nov 8, 2021)
5d5e2e4  alltoall_v_inplace test (Youhe-Jiang, Nov 9, 2021)
2 changes: 2 additions & 0 deletions bagua/torch_api/__init__.py
@@ -42,6 +42,8 @@
allgather_inplace,
alltoall,
alltoall_inplace,
alltoall_v,
alltoall_v_inplace,
reduce_scatter,
reduce_scatter_inplace,
ReduceOp,
76 changes: 76 additions & 0 deletions bagua/torch_api/communication.py
@@ -1097,6 +1097,82 @@ def alltoall_inplace(
comm.cuda_stream.synchronize()


def alltoall_v(
send_tensor: torch.Tensor,
send_counts: int,
send_displs: int,
recv_tensor: torch.Tensor,
recv_counts: int,
recv_displs: int,
comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
"""
Each process scatters :attr:`send_tensor` to all processes associated with the communicator and returns the gathered
data in :attr:`recv_tensor`. Unlike :func:`alltoall`, each process may send a different number of elements to each
peer, with explicit displacements into the send and receive buffers.

Args:
send_tensor (torch.Tensor): Input of the collective; must be large enough to hold the data described by :attr:`send_counts` and :attr:`send_displs`.
send_counts: List of ``comm.nranks`` integers. Entry ``j`` specifies the number of elements to send to process ``j``.
send_displs: List of ``comm.nranks`` integers. Entry ``j`` specifies the displacement (relative to :attr:`send_tensor`) from which to take the outgoing data destined for process ``j``.
recv_tensor (torch.Tensor): Output of the collective; must be large enough to hold the data described by :attr:`recv_counts` and :attr:`recv_displs`.
recv_counts: List of ``comm.nranks`` integers. Entry ``i`` specifies the number of elements to receive from process ``i``.
recv_displs: List of ``comm.nranks`` integers. Entry ``i`` specifies the displacement (relative to :attr:`recv_tensor`) at which to place the incoming data from process ``i``.
comm: A handle of the Bagua communicator to work on. By default, the global
communicator of the default process group will be used.
"""
if _rank_not_in_comm(comm):
return

assert send_tensor.device != torch.device(
"cpu"
), "send tensor must be CUDA and dense"
assert recv_tensor.device != torch.device(
"cpu"
), "recv tensor must be CUDA and dense"

if comm is None or comm is CommMember.WORLD:
comm = _get_default_group().get_global_communicator()

event = torch.cuda.current_stream().record_event()
comm.cuda_stream.wait_event(event)

with torch.cuda.stream(comm.cuda_stream):
comm.alltoall_v(
send_tensor.to_bagua_tensor().bagua_backend_tensor(),
send_counts,
send_displs,
recv_tensor.to_bagua_tensor().bagua_backend_tensor(),
recv_counts,
recv_displs,
)

comm.cuda_stream.synchronize()


def alltoall_v_inplace(
tensor: torch.Tensor,
counts: int,
displs: int,
comm: Optional[B.BaguaSingleCommunicatorPy] = None,
):
"""The in-place version of :func:`alltoall_v`."""
if _rank_not_in_comm(comm):
return

assert tensor.device != torch.device("cpu"), "tensor must be CUDA and dense"

if comm is None or comm is CommMember.WORLD:
comm = _get_default_group().get_global_communicator()

event = torch.cuda.current_stream().record_event()
comm.cuda_stream.wait_event(event)

with torch.cuda.stream(comm.cuda_stream):
comm.alltoall_v_inplace(tensor.to_bagua_tensor().bagua_backend_tensor(), counts, displs)

comm.cuda_stream.synchronize()


def barrier(comm: Optional[B.BaguaSingleCommunicatorPy] = None):
"""
Synchronizes all processes.
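For context (illustrative only, not part of the diff): a minimal usage sketch of the new :func:`alltoall_v` API with uneven per-rank counts. It assumes a 2-rank job whose Bagua process group has already been initialized (e.g. via ``bagua.init_process_group()``) and whose CUDA device has been selected with ``torch.cuda.set_device(bagua.get_local_rank())``.

import torch
import bagua.torch_api as bagua

# rank 0 sends 1 element to rank 0 and 3 to rank 1; rank 1 sends 2 elements to each peer
rank = bagua.get_rank()
send_counts = [1, 3] if rank == 0 else [2, 2]
recv_counts = [1, 2] if rank == 0 else [3, 2]  # entry j is what peer j sends to this rank
send_displs = [0, 1] if rank == 0 else [0, 2]  # prefix sums of send_counts
recv_displs = [0, 1] if rank == 0 else [0, 3]  # prefix sums of recv_counts

send_tensor = torch.arange(sum(send_counts), dtype=torch.float32).cuda()
recv_tensor = torch.zeros(sum(recv_counts), dtype=torch.float32).cuda()

bagua.alltoall_v(send_tensor, send_counts, send_displs, recv_tensor, recv_counts, recv_displs)

The counts and displacements are ordinary Python lists, matching the ``Vec<usize>`` arguments of the Rust binding added further down in this PR.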
32 changes: 32 additions & 0 deletions examples/communication_primitives/main.py
@@ -145,6 +145,38 @@ def main():
a=recv_tensors, b=recv_tensor_bagua
)

# alltoall_v: this block assumes a world size of 8 (each rank exchanges 2 elements with every peer)
send_tensors = torch.arange(16, dtype=torch.float32).cuda()
recv_tensors = torch.zeros(16, dtype=torch.float32).cuda()
recv_tensor_bagua = torch.zeros(16, dtype=torch.float32).cuda()
# split sizes for torch.distributed.all_to_all_single
in_splits = [2, 2, 2, 2, 2, 2, 2, 2]
out_splits = [2, 2, 2, 2, 2, 2, 2, 2]
# per-rank counts and displacements (prefix sums of the counts) for bagua.alltoall_v
send_counts = [2, 2, 2, 2, 2, 2, 2, 2]
recv_counts = [2, 2, 2, 2, 2, 2, 2, 2]
send_displs = [0, 2, 4, 6, 8, 10, 12, 14]
recv_displs = [0, 2, 4, 6, 8, 10, 12, 14]
dist.all_to_all_single(recv_tensors, send_tensors, out_splits, in_splits)
bagua.alltoall_v(
send_tensors,
send_counts,
send_displs,
recv_tensor_bagua,
recv_counts,
recv_displs,
comm=comm,
)
# the in-place variant overwrites send_tensors with the received data
bagua.alltoall_v_inplace(send_tensors, send_counts, send_displs)
assert torch.equal(
recv_tensors, recv_tensor_bagua
), "recv_tensors:{a}, recv_tensor_bagua:{b}".format(
a=recv_tensors, b=recv_tensor_bagua
)
assert torch.equal(
send_tensors, recv_tensor_bagua
), "send_tensors:{a}, recv_tensor_bagua:{b}".format(
a=send_tensors, b=recv_tensor_bagua
)


if __name__ == "__main__":
main()
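The block added above uses equal splits, so it does not exercise the variable-size case that distinguishes ``alltoall_v`` from ``alltoall``. An illustrative sketch of an uneven-split variant (not part of the diff), reusing the ``dist``, ``bagua``, and ``comm`` names from main.py and assuming a world size of 4:

# each rank sends (rank + 1) elements to every peer and receives (r + 1) elements from peer r
rank = bagua.get_rank()
in_splits = [rank + 1] * 4
out_splits = [r + 1 for r in range(4)]
send_displs = [sum(in_splits[:i]) for i in range(4)]   # prefix sums of in_splits
recv_displs = [sum(out_splits[:i]) for i in range(4)]  # prefix sums of out_splits

send_buf = torch.rand(sum(in_splits), dtype=torch.float32).cuda()
recv_ref = torch.zeros(sum(out_splits), dtype=torch.float32).cuda()
recv_bagua = torch.zeros(sum(out_splits), dtype=torch.float32).cuda()

# torch.distributed derives displacements from the split sizes; bagua.alltoall_v takes them explicitly
dist.all_to_all_single(recv_ref, send_buf, out_splits, in_splits)
bagua.alltoall_v(send_buf, in_splits, send_displs, recv_bagua, out_splits, recv_displs, comm=comm)
assert torch.equal(recv_ref, recv_bagua)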
41 changes: 41 additions & 0 deletions rust/bagua-core/bagua-core-internal/src/communicators/mod.rs
@@ -163,6 +163,11 @@ impl BaguaSingleCommunicator {
);
}

pub fn alltoall_v_inplace(&self, tensor: &mut BaguaTensor, counts: &[usize], displs: &[usize]) {
self.inner
.alltoall_v_inplace(tensor.inner.write().raw.as_mut(), counts, displs);
}

pub fn allgather(&self, send_tensor: &BaguaTensor, recv_tensor: &mut BaguaTensor) {
self.inner.allgather(
send_tensor.inner.read().raw.as_ref(),
@@ -675,6 +680,42 @@
}
}

pub fn alltoall_v_inplace(
&self,
tensor: &dyn RawBaguaTensor,
counts: &[usize],
displs: &[usize],
) {
let communicator_ptr = self.comm_ptr;
let nranks = self.nranks;
let nccl_tensor_type = tensor.dtype().to_nccl_datatype();

let tensor_ptr = tensor.data_ptr();
let counts_ptr = counts.as_ptr();
let displs_ptr = displs.as_ptr();
unsafe {
cpp::cpp!([
tensor_ptr as "void *", counts_ptr as "size_t *", displs_ptr as "size_t *",
nranks as "size_t", communicator_ptr as "Al::NCCLCommunicator *", nccl_tensor_type as "ncclDataType_t"]
{
std::vector<size_t> counts(counts_ptr, counts_ptr + nranks);
std::vector<size_t> displs(displs_ptr, displs_ptr + nranks);
if (nccl_tensor_type == ncclDataType_t::ncclFloat32) {
Al::Alltoallv<Al::NCCLBackend>(static_cast<float*>(tensor_ptr), counts, displs, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclFloat16) {
Al::Alltoallv<Al::NCCLBackend>(static_cast<__half*>(tensor_ptr), counts, displs, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclUint8) {
Al::Alltoallv<Al::NCCLBackend>(static_cast<unsigned char*>(tensor_ptr), counts, displs, *communicator_ptr);
} else if (nccl_tensor_type == ncclDataType_t::ncclInt64) {
Al::Alltoallv<Al::NCCLBackend>(static_cast<long long int*>(tensor_ptr), counts, displs, *communicator_ptr);
} else {
fputs("unsupport tensor data type.\n", stderr);
abort();
}
});
}
}

pub fn send(&self, send_tensor: &dyn RawBaguaTensor, peer_rank: i32) {
let communicator_ptr = self.comm_ptr;
let tensor_ptr = send_tensor.data_ptr();
10 changes: 10 additions & 0 deletions rust/bagua-core/bagua-core-py/src/lib.rs
@@ -136,6 +136,16 @@ impl BaguaSingleCommunicatorPy {
)
}

pub fn alltoall_v_inplace(
&self,
tensor: &mut BaguaTensorPy,
counts: Vec<usize>,
displs: Vec<usize>,
) {
self.inner
.alltoall_v_inplace(&mut tensor.inner, &counts, &displs)
}

pub fn allgather(&self, send_tensor: &BaguaTensorPy, recv_tensor: &mut BaguaTensorPy) {
self.inner
.allgather(&send_tensor.inner, &mut recv_tensor.inner)