feat: support nccl 2.10 ReduceOp.AVG #149

Merged 5 commits on Jul 30, 2021. Showing changes from 1 commit.
1 change: 1 addition & 0 deletions bagua/torch_api/__init__.py
@@ -22,6 +22,7 @@
alltoall_inplace,
reduce_scatter,
reduce_scatter_inplace,
ReduceOp,
)
from .distributed import BaguaModule # noqa: F401
from .tensor import BaguaTensor # noqa: F401
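With this change, ReduceOp is re-exported from the top-level bagua.torch_api package. A minimal usage sketch (assuming bagua is installed; the assert only restates the enum value defined in communication.py below):

from bagua.torch_api import ReduceOp  # newly exported by this PR

# The IntEnum values mirror Aluminum's ReductionOperator codes,
# so AVG maps straight to the backend code 10.
assert int(ReduceOp.AVG) == 10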
50 changes: 31 additions & 19 deletions bagua/torch_api/communication.py
@@ -13,13 +13,25 @@
get_default_bucket_size,
get_bagua_service_port,
)
from .utils import flatten, unflatten, to_bagua_reduce_op
from enum import IntEnum
from .utils import flatten, unflatten
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from bagua.service.autotune_service import AutotuneClient
from functools import lru_cache

# must be consistent with Aluminum ReductionOperator: https://github.com/BaguaSys/Aluminum/blob/master/include/aluminum/base.hpp
class ReduceOp(IntEnum):
SUM = 0
PRODUCT = 1
MIN = 2
MAX = 3
BOR = 7
BAND = 8
BXOR = 9
AVG = 10


@lru_cache(maxsize=None)
def get_hyperparameters_service_client():
@@ -320,7 +332,7 @@ def reduce(
send_tensor,
recv_tensor,
dst,
op=dist.ReduceOp.SUM,
op=ReduceOp.SUM,
comm: B.BaguaSingleCommunicatorPy = None,
):
r"""Reduces the tensor across all processes.
@@ -331,7 +343,7 @@ def reduce(
send_tensor (torch.Tensor): Input of the collective.
recv_tensor (torch.Tensor): Output of the collective, must have the same size as send_tensor.
dst (int): Destination rank.
op (optional): one of the values from `torch.distributed.ReduceOp`
op (optional): one of the values from `bagua.ReduceOp`
enum. Specifies an operation used for element-wise reductions.
comm (B.BaguaSingleCommunicatorPy, optional): The bagua communicator to
work on. If None the global bagua communicator will be used.
@@ -356,14 +368,14 @@ def reduce(
send_tensor.to_bagua_tensor().bagua_backend_tensor(),
recv_tensor.to_bagua_tensor().bagua_backend_tensor(),
dst,
to_bagua_reduce_op(op),
int(op),
)

torch.cuda.synchronize()


def reduce_inplace(
tensor, dst, op=dist.ReduceOp.SUM, comm: B.BaguaSingleCommunicatorPy = None
tensor, dst, op=ReduceOp.SUM, comm: B.BaguaSingleCommunicatorPy = None
):
r"""The inplace version of reduce."""

@@ -377,15 +389,15 @@ def reduce_inplace(

with torch.cuda.stream(comm.cuda_stream):
comm.reduce_inplace(
tensor.to_bagua_tensor().bagua_backend_tensor(), dst, to_bagua_reduce_op(op)
tensor.to_bagua_tensor().bagua_backend_tensor(), dst, int(op)
)

torch.cuda.synchronize()


def allreduce_coalesced_inplace(
tensors,
op=dist.ReduceOp.SUM,
op=ReduceOp.SUM,
comm: B.BaguaSingleCommunicatorPy = None,
):
for tensor in tensors:
@@ -402,7 +414,7 @@ def allreduce_coalesced_inplace(
with torch.cuda.stream(comm.cuda_stream):
coalesced = flatten(tensors)
comm.allreduce_inplace(
coalesced.to_bagua_tensor("allreduce_coalesced"), to_bagua_reduce_op(op)
coalesced.to_bagua_tensor("allreduce_coalesced"), int(op)
)

for buf, synced in zip(tensors, unflatten(coalesced, tensors)):
@@ -415,7 +427,7 @@ def allreduce_coalesced_inplace(
def allreduce(
send_tensor,
recv_tensor,
op=dist.ReduceOp.SUM,
op=ReduceOp.SUM,
comm: B.BaguaSingleCommunicatorPy = None,
):
"""Reduces the tensor data across all machines in such a way that all get
@@ -425,7 +437,7 @@ def allreduce(
Args:
send_tensor (torch.Tensor): Input of the collective.
recv_tensor (torch.Tensor): Output of the collective, must have the same size as send_tensor.
op (optional): one of the values from `torch.distributed.ReduceOp` enum. Specifies an operation used for element-wise reductions.
op (optional): one of the values from `bagua.ReduceOp` enum. Specifies an operation used for element-wise reductions.
comm (B.BaguaSingleCommunicatorPy, optional): The bagua communicator to
work on. If None the global bagua communicator will be used.
Defaults to None.
@@ -474,7 +486,7 @@ def allreduce(
comm.allreduce(
send_tensor.to_bagua_tensor().bagua_backend_tensor(),
recv_tensor.to_bagua_tensor().bagua_backend_tensor(),
to_bagua_reduce_op(op),
int(op),
)

# TODO: remove
@@ -483,7 +495,7 @@ def allreduce(

def allreduce_inplace(
tensor,
op=dist.ReduceOp.SUM,
op=ReduceOp.SUM,
comm: B.BaguaSingleCommunicatorPy = None,
):
"""The inplace version of allreduce."""
@@ -498,7 +510,7 @@ def allreduce_inplace(

with torch.cuda.stream(comm.cuda_stream):
comm.allreduce_inplace(
tensor.to_bagua_tensor().bagua_backend_tensor(), to_bagua_reduce_op(op)
tensor.to_bagua_tensor().bagua_backend_tensor(), int(op)
)

torch.cuda.synchronize()
@@ -712,15 +724,15 @@ def scatter_inplace(
def reduce_scatter(
send_tensor,
recv_tensor,
op=dist.ReduceOp.SUM,
op=ReduceOp.SUM,
comm: B.BaguaSingleCommunicatorPy = None,
):
"""Reduces on send_tensor, then scatters send_tensor to all machines.

Args:
send_tensor (torch.Tensor): Input of the collective, must have size recv_tensor.size()*comm.nranks.
recv_tensor (torch.Tensor): Output of the collective.
op (optional): one of the values from `torch.distributed.ReduceOp` enum. Specifies an operation used for element-wise reductions.
op (optional): one of the values from `bagua.ReduceOp` enum. Specifies an operation used for element-wise reductions.
comm (B.BaguaSingleCommunicatorPy, optional): The bagua communicator to
work on. If None the global bagua communicator will be used.
Defaults to None.
@@ -743,22 +755,22 @@ def reduce_scatter(
comm.reduce_scatter(
send_tensor.to_bagua_tensor().bagua_backend_tensor(),
recv_tensor.to_bagua_tensor().bagua_backend_tensor(),
to_bagua_reduce_op(op),
int(op),
)

torch.cuda.synchronize()


def reduce_scatter_inplace(
tensor,
op=dist.ReduceOp.SUM,
op=ReduceOp.SUM,
comm: B.BaguaSingleCommunicatorPy = None,
):
"""The inplace version of reduce_scatter.

Args:
tensor (torch.Tensor): Input and output of the collective, must satisfy: `tensor.size() % comm.nranks == 0`.
op (optional): one of the values from `torch.distributed.ReduceOp` enum. Specifies an operation used for element-wise reductions.
op (optional): one of the values from `bagua.ReduceOp` enum. Specifies an operation used for element-wise reductions.
comm (B.BaguaSingleCommunicatorPy, optional): The bagua communicator to
work on. If None the global bagua communicator will be used.
Defaults to None.
@@ -774,7 +786,7 @@ def reduce_scatter_inplace(

with torch.cuda.stream(comm.cuda_stream):
comm.reduce_scatter_inplace(
tensor.to_bagua_tensor().bagua_backend_tensor(), to_bagua_reduce_op(op)
tensor.to_bagua_tensor().bagua_backend_tensor(), int(op)
)

torch.cuda.synchronize()
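For context, a hedged end-to-end sketch of the new AVG reduction. The names get_local_rank, get_rank, init_process_group, and allreduce_inplace are taken from bagua's public API as assumed here; the launcher is expected to set the usual rank/world-size environment variables, and AVG requires NCCL >= 2.10 on the backend:

import torch
import bagua.torch_api as bagua

# Assumes the script is started with bagua's distributed launcher.
torch.cuda.set_device(bagua.get_local_rank())
bagua.init_process_group()

x = torch.full((4,), float(bagua.get_rank()), device="cuda")
# Element-wise average across all ranks, using the new ReduceOp.AVG.
bagua.allreduce_inplace(x, op=bagua.ReduceOp.AVG)
# Every rank now holds (0 + 1 + ... + (n-1)) / n in each element.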
19 changes: 0 additions & 19 deletions bagua/torch_api/utils.py
@@ -215,25 +215,6 @@ def to_bagua_datatype(datatype):
raise ValueError(f"unsupported data type {datatype}.")


def to_bagua_reduce_op(torch_reduce_op):
if torch_reduce_op == dist.ReduceOp.SUM:
return 0
elif torch_reduce_op == dist.ReduceOp.PRODUCT:
return 1
elif torch_reduce_op == dist.ReduceOp.MIN:
return 2
elif torch_reduce_op == dist.ReduceOp.MAX:
return 3
elif torch_reduce_op == dist.ReduceOp.BOR:
return 7
elif torch_reduce_op == dist.ReduceOp.BAND:
return 8
elif torch_reduce_op == dist.ReduceOp.BXOR:
return 9
else:
raise ValueError(f"unsupported reduce op {torch_reduce_op}.")


def average_by_removing_extreme_values(raw_score_list):
score_list = np.asarray(raw_score_list)

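With ReduceOp defined as an IntEnum whose values already match the Aluminum codes, the removed to_bagua_reduce_op lookup becomes redundant: int(op) yields the backend code directly. A small illustrative check, restating the values from communication.py above:

from bagua.torch_api import ReduceOp

# int() on the IntEnum reproduces the codes the removed helper hard-coded.
assert int(ReduceOp.SUM) == 0
assert int(ReduceOp.PRODUCT) == 1
assert int(ReduceOp.MAX) == 3
assert int(ReduceOp.AVG) == 10  # new in NCCL 2.10; the removed helper had no mapping for it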