Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support process group #228

Merged
merged 21 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bagua/torch_api/algorithms/async_model_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ def init_operations(
bucket.append_centralized_synchronous_op(
hierarchical=False,
average=True,
group=bagua_module._bagua_process_group,
)
else:
async_op = bucket.append_asynchronous_model_average_op(
peer_selection_mode=self.peer_selection_mode,
group=bagua_module._bagua_process_group,
)
bucket._async_op = async_op

Expand Down
1 change: 1 addition & 0 deletions bagua/torch_api/algorithms/bytegrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@ def init_operations(
average=self.average,
scattergather=True,
compression="MinMaxUInt8",
group=bagua_module._bagua_process_group,
)
2 changes: 2 additions & 0 deletions bagua/torch_api/algorithms/decentralized.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def init_operations(
peer_weight=bucket._peer_weight,
hierarchical=self.hierarchical,
peer_selection_mode=self.peer_selection_mode,
group=bagua_module._bagua_process_group,
)


Expand Down Expand Up @@ -187,4 +188,5 @@ def init_operations(
right_peer_weight=bucket._right_peer_weight,
hierarchical=self.hierarchical,
compression="MinMaxUInt8",
group=bagua_module._bagua_process_group,
)
15 changes: 5 additions & 10 deletions bagua/torch_api/algorithms/gradient_allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,8 @@ def init_operations(
bucket: BaguaBucket,
):
bucket.clear_ops()
if self.hierarchical:
bucket.append_centralized_synchronous_op(
hierarchical=self.hierarchical,
average=self.average,
)
else:
bucket.append_centralized_synchronous_op(
hierarchical=self.hierarchical,
average=self.average,
)
bucket.append_centralized_synchronous_op(
hierarchical=self.hierarchical,
average=self.average,
group=bagua_module._bagua_process_group,
)
2 changes: 2 additions & 0 deletions bagua/torch_api/algorithms/q_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def init_operations(
bucket.append_centralized_synchronous_op(
hierarchical=False,
average=True,
group=bagua_module._bagua_process_group,
)
else:

Expand All @@ -186,6 +187,7 @@ def calculate_momentum(*args):
average=True,
scattergather=True,
compression="MinMaxUInt8",
group=bagua_module._bagua_process_group,
)

def init_backward_hook(self, bagua_module: BaguaModule):
Expand Down
64 changes: 47 additions & 17 deletions bagua/torch_api/bucket.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
#!/usr/bin/env python3

from __future__ import annotations
from bagua.torch_api.communication import get_backend
from bagua.torch_api.communication import get_backend, _get_default_group
from typing import List, Callable, Optional

import bagua_core as B
import torch

from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.utils import check_contiguous
from bagua.torch_api.communication import broadcast
from bagua.torch_api.communication import (
broadcast,
BaguaProcessGroup,
_bagua_backend_comm,
_rank_not_in_comm,
)


class BaguaBucket:
Expand Down Expand Up @@ -157,6 +162,7 @@ def append_centralized_synchronous_op(
average: bool = True,
scattergather: bool = False,
compression: Optional[str] = None,
group: Optional[BaguaProcessGroup] = None,
):
"""
Append a centralized synchronous operation to a bucket. It will sum or average the tensors in the bucket
Expand All @@ -174,19 +180,23 @@ def append_centralized_synchronous_op(
of allreduce. This is required for using compression.
compression: If not ``None``, the tensors will be compressed for communication. Currently ``"MinMaxUInt8"`` is
supported.
group: The process group to work on. If ``None``, the default process group will be used.
"""
if group is None:
group = _get_default_group()

if hierarchical:
self.backend_bucket.append_centralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
_bagua_backend_comm(group.get_inter_node_communicator()),
_bagua_backend_comm(group.get_intra_node_communicator()),
hierarchical=hierarchical,
average=average,
scattergather=scattergather,
compression=compression,
)
else:
self.backend_bucket.append_centralized_synchronous_op(
self._bagua_backend.global_communicator,
_bagua_backend_comm(group.get_global_communicator()),
None,
hierarchical=hierarchical,
average=average,
Expand All @@ -199,6 +209,7 @@ def append_decentralized_synchronous_op(
peer_weight: BaguaTensor,
hierarchical: bool = True,
peer_selection_mode: str = "all",
group: Optional[BaguaProcessGroup] = None,
):
"""
Append a decentralized synchronous operation to a bucket. It will do gossipy style model averaging among workers.
Expand All @@ -219,27 +230,33 @@ def append_decentralized_synchronous_op(
peer_selection_mode (str): Can be ``"all"`` or ``"shift_one"``. ``"all"`` means all workers' weights are averaged
in each communication step. ``"shift_one"`` means each worker selects a different peer to do weights average
in each communication step.
group: The process group to work on. If ``None``, the default process group will be used.
"""
if group is None:
group = _get_default_group()

if hierarchical:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
_bagua_backend_comm(group.get_inter_node_communicator()),
_bagua_backend_comm(group.get_intra_node_communicator()),
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
peer_weight=peer_weight._bagua_backend_tensor,
)
else:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
_bagua_backend_comm(group.get_global_communicator()),
None,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
peer_weight=peer_weight._bagua_backend_tensor,
)

def decentralized_synchronous_op_copy_back_peer_weight(
self, peer_weight: BaguaTensor, hierarchical: bool = True
self,
peer_weight: BaguaTensor,
hierarchical: bool = True,
group: Optional[BaguaProcessGroup] = None,
):
"""
Copy :attr:`peer_weight` back to bucket weights to end a decentralized synchronous operation.
Expand All @@ -252,11 +269,15 @@ def decentralized_synchronous_op_copy_back_peer_weight(
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high. Must be the same with :attr:`hierarchical` argument in
:meth:`append_decentralized_synchronous_op`.
group: The process group to work on. If ``None``, the default process group will be used.
"""
intra_comm = self._bagua_backend.intranode_communicator
inter_comm = self._bagua_backend.internode_communicator
if group is None:
group = _get_default_group()

intra_comm = group.get_intra_node_communicator()
inter_comm = group.get_inter_node_communicator()

if not hierarchical or (inter_comm is not None):
if not hierarchical or not _rank_not_in_comm(inter_comm):
self.backend_tensor.copy_(peer_weight)

if hierarchical:
Expand All @@ -269,6 +290,7 @@ def append_low_precision_decentralized_synchronous_op(
right_peer_weight: BaguaTensor,
hierarchical: bool = True,
compression: str = "MinMaxUInt8",
group: Optional[BaguaProcessGroup] = None,
):
"""
Append a low precision decentralized synchronous operation to a bucket. It will compress the difference
Expand All @@ -290,12 +312,15 @@ def append_low_precision_decentralized_synchronous_op(
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
compression (str): How tensors are compressed for communication. Currently ``"MinMaxUInt8"`` is supported.
group: The process group to work on. If ``None``, the default process group will be used.
"""
if group is None:
group = _get_default_group()

if hierarchical:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
_bagua_backend_comm(group.get_inter_node_communicator()),
_bagua_backend_comm(group.get_intra_node_communicator()),
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
Expand All @@ -305,7 +330,7 @@ def append_low_precision_decentralized_synchronous_op(
)
else:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
_bagua_backend_comm(group.get_global_communicator()),
None,
hierarchical=hierarchical,
peer_selection_mode="ring",
Expand All @@ -315,7 +340,9 @@ def append_low_precision_decentralized_synchronous_op(
right_peer_weight=right_peer_weight._bagua_backend_tensor,
)

def append_asynchronous_model_average_op(self, peer_selection_mode: str):
def append_asynchronous_model_average_op(
self, peer_selection_mode: str, group: Optional[BaguaProcessGroup] = None
):

"""
Append an asynchronous model average operation to a bucket. This operation will enable continuous
Expand All @@ -331,12 +358,15 @@ def append_asynchronous_model_average_op(self, peer_selection_mode: str):
Args:
peer_selection_mode (str): How workers communicate with each other. Currently ``"all"`` is supported.
``"all"`` means all workers' weights are averaged during each communication.
group: The process group to work on. If ``None``, the default process group will be used.
Returns:
The asynchronous model average operation itself.
"""
if group is None:
group = _get_default_group()

return self.backend_bucket.append_decentralized_asynchronous_op(
self._bagua_backend.global_communicator,
_bagua_backend_comm(group.get_global_communicator()),
None,
peer_selection_mode=peer_selection_mode,
torch_stream=torch.cuda.current_stream().cuda_stream,
Expand Down
Loading