Skip to content

Commit

Permalink
feat(python, core): support process group in with_bagua, support hi…
Browse files Browse the repository at this point in the history
…erarchical communication in bytegrad algorithm (#300)

BREAKING CHANGE: 1) `AlgorithmImpl` must pass a process group to its `__init__` method 2)    `decentralized_synchronous_op_copy_back_peer_weight` is now moved from `BaguaBucket` to decentralized synchronous op's `copy_back_peer_weight` method
  • Loading branch information
wangraying authored Oct 21, 2021
1 parent 65f8cd3 commit 4e1adda
Show file tree
Hide file tree
Showing 20 changed files with 526 additions and 226 deletions.
35 changes: 22 additions & 13 deletions bagua/torch_api/algorithms/async_model_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm, AlgorithmImpl
from bagua.torch_api.communication import new_group, broadcast, barrier, _pg_group_ranks
from bagua.torch_api.communication import (
new_group,
broadcast,
barrier,
_pg_group_ranks,
BaguaProcessGroup,
)
from typing import List
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.env import get_rank
Expand All @@ -14,7 +20,7 @@
import concurrent


__all__ = ["AsyncModelAverageAlgorithm"]
__all__ = ["AsyncModelAverageAlgorithm", "AsyncModelAverageAlgorithmImpl"]


class _AsyncInternalState(IntEnum):
Expand All @@ -25,6 +31,7 @@ class _AsyncInternalState(IntEnum):
class AsyncModelAverageAlgorithmImpl(AlgorithmImpl):
def __init__(
self,
process_group: BaguaProcessGroup,
peer_selection_mode: str = "all",
sync_interval_ms: int = 500,
warmup_steps: int = 0,
Expand All @@ -44,13 +51,15 @@ def __init__(
and resume with `model.bagua_algorithm.resume(model)`.
Args:
process_group (BaguaProcessGroup): The process group to work on.
peer_selection_mode (str): The way how workers communicate with each other. Currently ``"all"`` is supported.
``"all"`` means all workers' weights are synchronized during each communication.
sync_interval_ms (int): Number of milliseconds between model synchronizations.
warmup_steps (int): Number of steps to warm up by doing gradient allreduce before doing asynchronous
model averaging. Use 0 to disable.
"""

super(AsyncModelAverageAlgorithmImpl, self).__init__(process_group)
self.peer_selection_mode = peer_selection_mode
self.sync_interval_ms = sync_interval_ms
self.step_id = 0
Expand All @@ -64,6 +73,11 @@ def __init__(
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.scheduled = False

process_ranks = list(_pg_group_ranks[self.process_group])
self.thread_group = new_group(
process_ranks, stream=torch.cuda.Stream(priority=-1)
)

def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
if self.step_id < self.warmup_steps:
return super().tensors_to_buckets(tensors)
Expand Down Expand Up @@ -148,15 +162,9 @@ def init_operations(
bucket.append_centralized_synchronous_op(
hierarchical=False,
average=True,
group=bagua_module._bagua_process_group,
group=self.process_group,
)
else:
if not hasattr(self, "thread_group"):
process_ranks = list(_pg_group_ranks[bagua_module._bagua_process_group])
self.thread_group = new_group(
process_ranks, stream=torch.cuda.Stream(priority=-1)
)

async_op = bucket.append_asynchronous_model_average_op(
peer_selection_mode=self.peer_selection_mode, group=self.thread_group
)
Expand Down Expand Up @@ -185,7 +193,7 @@ def _negotiate(self):
broadcast(
self.dummy_tensor,
src=0,
comm=self.thread_group.get_global_communicator(), # pytype: disable=attribute-error
comm=self.thread_group.get_global_communicator(),
)

return self.dummy_tensor.item()
Expand Down Expand Up @@ -223,7 +231,7 @@ def abort(self, bagua_module: BaguaModule):
"""

if self.scheduled:
barrier(comm=bagua_module._bagua_process_group.get_global_communicator())
barrier(comm=self.process_group.get_global_communicator())
self.abort_event.set()
self.future.result() # pytype: disable=attribute-error
self.scheduled = False
Expand All @@ -239,7 +247,7 @@ def resume(self, bagua_module: BaguaModule):
"""

if not self.scheduled and hasattr(self, "future"):
barrier(comm=bagua_module._bagua_process_group.get_global_communicator())
barrier(comm=self.process_group.get_global_communicator())
self.abort_event.clear()
self.future = self.executor.submit(self._run_async_loop, bagua_module)
self.scheduled = True
Expand Down Expand Up @@ -279,8 +287,9 @@ def __init__(
self.sync_interval_ms = sync_interval_ms
self.warmup_steps = warmup_steps

def reify(self) -> AsyncModelAverageAlgorithmImpl:
def reify(self, process_group: BaguaProcessGroup) -> AsyncModelAverageAlgorithmImpl:
return AsyncModelAverageAlgorithmImpl(
process_group,
peer_selection_mode=self.peer_selection_mode,
sync_interval_ms=self.sync_interval_ms,
warmup_steps=self.warmup_steps,
Expand Down
14 changes: 12 additions & 2 deletions bagua/torch_api/algorithms/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.communication import BaguaProcessGroup
from typing import List
import torch

Expand All @@ -10,9 +11,12 @@ class Algorithm:
This is the base class that all Bagua algorithms inherit.
"""

def reify(self):
def reify(self, process_group: BaguaProcessGroup):
"""
Reify an algorithm instance.
Create an algorithm instance.
Args:
process_group: The process group to work on.
"""
pass

Expand All @@ -23,8 +27,14 @@ class AlgorithmImpl:
It provides methods that can be override to implement different kinds of
distributed algorithms.
Args:
process_group: The process group to work on.
"""

def __init__(self, process_group: BaguaProcessGroup):
self.process_group = process_group

def need_reset(self) -> bool:
"""
Returns:
Expand Down
28 changes: 22 additions & 6 deletions bagua/torch_api/algorithms/bytegrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,30 @@
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm, AlgorithmImpl
from bagua.torch_api import get_world_size
from bagua.torch_api.communication import BaguaProcessGroup
from typing import List


class ByteGradAlgorithmImpl(AlgorithmImpl):
def __init__(self, average: bool = True):
def __init__(
self,
process_group: BaguaProcessGroup,
hierarchical: bool = True,
average: bool = True,
):
"""
Implementation of the
`ByteGrad <https://tutorials.baguasys.com/algorithms/bytegrad>`_
algorithm.
Args:
process_group (BaguaProcessGroup): The process group to work on.
hierarchical (bool): Enable hierarchical communication.
average (bool): If ``True``, the gradients on each worker are averaged.
Otherwise, they are summed.
"""
super(ByteGradAlgorithmImpl, self).__init__(process_group)
self.hierarchical = hierarchical
self.average = average

def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
Expand All @@ -37,7 +46,10 @@ def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBuck
bagua_buckets = []
for idx, bucket in enumerate(tensors):
bagua_bucket = BaguaBucket(
bucket, flatten=True, name=str(idx), alignment=get_world_size()
bucket,
flatten=True,
name=str(idx),
alignment=self.process_group.get_global_communicator().nranks(),
)
bagua_buckets.append(bagua_bucket)
return bagua_buckets
Expand All @@ -49,7 +61,7 @@ def init_operations(
):
bucket.clear_ops()
bucket.append_centralized_synchronous_op(
hierarchical=True,
hierarchical=self.hierarchical,
average=self.average,
scattergather=True,
compression="MinMaxUInt8",
Expand All @@ -58,19 +70,23 @@ def init_operations(


class ByteGradAlgorithm(Algorithm):
def __init__(self, average: bool = True):
def __init__(self, hierarchical: bool = True, average: bool = True):
"""
Create an instance of the
`ByteGrad <https://tutorials.baguasys.com/algorithms/bytegrad>`_
algorithm.
Args:
hierarchical (bool): Enable hierarchical communication.
average (bool): If ``True``, the gradients on each worker are averaged.
Otherwise, they are summed.
"""
self.hierarchical = hierarchical
self.average = average

def reify(self) -> ByteGradAlgorithmImpl:
def reify(self, process_group: BaguaProcessGroup) -> ByteGradAlgorithmImpl:
return ByteGradAlgorithmImpl(
process_group,
hierarchical=self.hierarchical,
average=self.average,
)
36 changes: 28 additions & 8 deletions bagua/torch_api/algorithms/decentralized.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm, AlgorithmImpl
from bagua.torch_api.communication import BaguaProcessGroup
from typing import List
import torch


class DecentralizedAlgorithmImpl(AlgorithmImpl):
def __init__(
self,
process_group: BaguaProcessGroup,
hierarchical: bool = True,
peer_selection_mode: str = "all",
communication_interval: int = 1,
Expand All @@ -20,16 +22,19 @@ def __init__(
algorithm.
Args:
process_group (BaguaProcessGroup): The process group to work on.
hierarchical (bool): Enable hierarchical communication.
peer_selection_mode (str): Can be ``"all"`` or ``"shift_one"``. ``"all"`` means all workers'
weights are averaged in each communication step. ``"shift_one"`` means each worker
selects a different peer to do weights average in each communication step.
communication_interval (int): Number of iterations between two communication steps.
"""
super(DecentralizedAlgorithmImpl, self).__init__(process_group)
self.hierarchical = hierarchical
self.peer_selection_mode = peer_selection_mode
self.communication_interval = communication_interval
self.cuda_event = torch.cuda.Event()

def _should_communicate(self, bagua_module: BaguaModule) -> bool:
cur_step = bagua_module.bagua_train_step_counter - 1
Expand Down Expand Up @@ -70,9 +75,12 @@ def init_post_backward_hook(self, bagua_module: BaguaModule):
def hook():
if self._should_communicate(bagua_module):
bagua_module._bagua_backend.wait_pending_comm_ops()

torch.cuda.current_stream().record_event(self.cuda_event)
self.cuda_event.synchronize()
for bucket in bagua_module.bagua_buckets:
bucket.decentralized_synchronous_op_copy_back_peer_weight(
hierarchical=self.hierarchical, peer_weight=bucket._peer_weight
bucket._decentralized_op.copy_back_peer_weight(
bucket.backend_bucket
)

return hook
Expand All @@ -89,25 +97,33 @@ def init_operations(
self._init_states(bucket)
torch.cuda.synchronize()
bucket.clear_ops()
bucket.append_decentralized_synchronous_op(
decentralized_op = bucket.append_decentralized_synchronous_op(
peer_weight=bucket._peer_weight,
hierarchical=self.hierarchical,
peer_selection_mode=self.peer_selection_mode,
group=bagua_module._bagua_process_group,
group=self.process_group,
)
bucket._decentralized_op = decentralized_op


class LowPrecisionDecentralizedAlgorithmImpl(AlgorithmImpl):
def __init__(self, hierarchical: bool = True, communication_interval: int = 1):
def __init__(
self,
process_group: BaguaProcessGroup,
hierarchical: bool = True,
communication_interval: int = 1,
):
"""
Implementation of the
`Low Precision Decentralized SGD <https://tutorials.baguasys.com/algorithms/low-precision-decentralized>`_
algorithm.
Args:
process_group (BaguaProcessGroup): The process group to work on.
hierarchical (bool): Enable hierarchical communication.
communication_interval (int): Number of iterations between two communication steps.
"""
super(LowPrecisionDecentralizedAlgorithmImpl, self).__init__(process_group)
self.hierarchical = hierarchical
self.communication_interval = communication_interval

Expand Down Expand Up @@ -188,7 +204,7 @@ def init_operations(
right_peer_weight=bucket._right_peer_weight,
hierarchical=self.hierarchical,
compression="MinMaxUInt8",
group=bagua_module._bagua_process_group,
group=self.process_group,
)


Expand Down Expand Up @@ -216,8 +232,9 @@ def __init__(
self.peer_selection_mode = peer_selection_mode
self.communication_interval = communication_interval

def reify(self) -> DecentralizedAlgorithmImpl:
def reify(self, process_group: BaguaProcessGroup) -> DecentralizedAlgorithmImpl:
return DecentralizedAlgorithmImpl(
process_group,
hierarchical=self.hierarchical,
peer_selection_mode=self.peer_selection_mode,
communication_interval=self.communication_interval,
Expand All @@ -238,8 +255,11 @@ def __init__(self, hierarchical: bool = True, communication_interval: int = 1):
self.hierarchical = hierarchical
self.communication_interval = communication_interval

def reify(self) -> LowPrecisionDecentralizedAlgorithmImpl:
def reify(
self, process_group: BaguaProcessGroup
) -> LowPrecisionDecentralizedAlgorithmImpl:
return LowPrecisionDecentralizedAlgorithmImpl(
process_group,
hierarchical=self.hierarchical,
communication_interval=self.communication_interval,
)
Loading

0 comments on commit 4e1adda

Please sign in to comment.