feat: make full precision decentralized op stateless (#126)
BREAKING CHANGE: `BaguaBucket.append_decentralized_synchronous_op` now only supports full precision decentralized communication.
wangraying authored Jul 21, 2021
1 parent ef7399e commit 0c978e9
Showing 5 changed files with 529 additions and 63 deletions.
51 changes: 34 additions & 17 deletions bagua/torch_api/algorithms/decentralized.py
@@ -25,11 +25,16 @@ def __init__(
weights are averaged in each communication step. "shift_one" means each worker
selects a different peer to average weights with in each communication step.
communication_interval (int): Number of iterations between two communication steps.
"""
self.hierarchical = hierarchical
self.peer_selection_mode = peer_selection_mode
self.communication_interval = communication_interval

def _should_communicate(self, bagua_module: BaguaModule) -> bool:
cur_step = bagua_module.bagua_train_step_counter - 1
return cur_step % self.communication_interval == 0
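
The counter-minus-one makes the step index zero-based, so with `communication_interval = k` workers communicate on training steps 1, k+1, 2k+1, ... (assuming `bagua_train_step_counter` is already 1 when the hooks of the first step run). A minimal standalone sketch of the same gating, in plain Python rather than Bagua code:

# Standalone sketch of the gating above; assumes a 1-based train step counter,
# as maintained by BaguaModule.
def should_communicate(train_step_counter: int, communication_interval: int) -> bool:
    cur_step = train_step_counter - 1  # zero-based step index
    return cur_step % communication_interval == 0

# With communication_interval=2, steps 1, 3, 5, ... communicate:
assert [should_communicate(s, 2) for s in range(1, 6)] == [True, False, True, False, True]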

def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
parameters = bagua_module.bagua_build_params()
self.tensors = [
@@ -40,8 +45,9 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:

def init_forward_pre_hook(self, bagua_module: BaguaModule):
def hook(input):
for tensor in self.tensors:
tensor.bagua_mark_communication_ready()
if self._should_communicate(bagua_module):
for tensor in self.tensors:
tensor.bagua_mark_communication_ready()

return hook

@@ -53,23 +59,31 @@ def hook(parameter_name, parameter):

def init_post_backward_hook(self, bagua_module: BaguaModule):
def hook():
bagua_module._bagua_backend.wait_pending_comm_ops()
torch.cuda.synchronize()
bagua_module._bagua_backend.execute_post_backward_comm_ops()
bagua_module._bagua_backend.wait_pending_post_backward_comm_ops()
if self._should_communicate(bagua_module):
bagua_module._bagua_backend.wait_pending_comm_ops()
for bucket in bagua_module.bagua_buckets:
bucket.decentralized_synchronous_op_copy_back_peer_weight(
hierarchical=self.hierarchical, peer_weight=bucket._peer_weight
)

return hook

def _init_states(self, bucket: BaguaBucket):
weight_tensor = bucket.flattened_tensor()
bucket._peer_weight = weight_tensor.to_bagua_tensor("peer_weight")

def init_operations(
self,
bagua_module: BaguaModule,
bucket: BaguaBucket,
):
self._init_states(bucket)
torch.cuda.synchronize()
bucket.clear_ops()
bucket.append_decentralized_synchronous_op(
peer_weight=bucket._peer_weight,
hierarchical=self.hierarchical,
peer_selection_mode=self.peer_selection_mode,
communication_interval=self.communication_interval,
)
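
A hedged end-to-end usage sketch for this algorithm. The class and setup names are assumptions taken from the Bagua documentation rather than from this diff: `DecentralizedAlgorithm` in `bagua.torch_api.algorithms.decentralized`, `bagua.init_process_group()`, and `Module.with_bagua`:

# Hedged usage sketch; class name and setup calls assumed from the Bagua docs.
import torch
import bagua.torch_api as bagua
from bagua.torch_api.algorithms import decentralized

bagua.init_process_group()  # one process per GPU, typically launched with bagua.distributed.launch
model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

algorithm = decentralized.DecentralizedAlgorithm(
    hierarchical=True,
    peer_selection_mode="all",  # or "shift_one"
    communication_interval=1,
)
model = model.with_bagua([optimizer], algorithm)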


@@ -87,6 +101,10 @@ def __init__(self, hierarchical: bool = True, communication_interval: int = 1):
self.hierarchical = hierarchical
self.communication_interval = communication_interval

def _should_communicate(self, bagua_module: BaguaModule) -> bool:
cur_step = bagua_module.bagua_train_step_counter - 1
return cur_step % self.communication_interval == 0

def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
parameters = bagua_module.bagua_build_params()
self.tensors = [
@@ -123,12 +141,13 @@ def hook():

def init_post_optimizer_step_hook(self, bagua_module: BaguaModule):
def hook(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
if param.is_bagua_tensor():
param.bagua_mark_communication_ready()
if self._should_communicate(bagua_module):
for group in optimizer.param_groups:
for param in group["params"]:
if param.is_bagua_tensor():
param.bagua_mark_communication_ready()

bagua_module._bagua_backend.wait_pending_comm_ops()
bagua_module._bagua_backend.wait_pending_comm_ops()

return hook

@@ -153,12 +172,10 @@ def init_operations(
self._init_states(bucket)
torch.cuda.synchronize()
bucket.clear_ops()
bucket.append_decentralized_synchronous_op(
hierarchical=self.hierarchical,
peer_selection_mode="ring",
communication_interval=self.communication_interval,
compression="MinMaxUInt8",
bucket.append_low_precision_decentralized_synchronous_op(
weight=bucket._weight,
left_peer_weight=bucket._left_peer_weight,
right_peer_weight=bucket._right_peer_weight,
hierarchical=self.hierarchical,
compression="MinMaxUInt8",
)
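
The low precision variant is selected the same way; a hedged sketch, assuming the class name `LowPrecisionDecentralizedAlgorithm` from the Bagua documentation (model and optimizer as in the previous sketch):

# Hedged sketch; only the algorithm object changes compared to the full precision setup.
from bagua.torch_api.algorithms import decentralized

algorithm = decentralized.LowPrecisionDecentralizedAlgorithm(
    hierarchical=True,
    communication_interval=1,
)
model = model.with_bagua([optimizer], algorithm)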
132 changes: 95 additions & 37 deletions bagua/torch_api/bucket.py
@@ -9,6 +9,7 @@

from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.utils import check_contiguous
from bagua.torch_api.communication import broadcast


class BaguaBucket:
@@ -203,71 +204,128 @@ def append_centralized_synchronous_op(

def append_decentralized_synchronous_op(
self,
peer_weight: BaguaTensor,
hierarchical: bool = True,
peer_selection_mode: str = "all",
communication_interval: int = 1,
compression: Optional[str] = None,
weight: Optional[BaguaTensor] = None,
left_peer_weight: Optional[BaguaTensor] = None,
right_peer_weight: Optional[BaguaTensor] = None,
) -> BaguaBucket:
"""
Append a decentralized synchronous operation to a bucket. It will do gossip-style model averaging among workers.
The operations will be executed by the Bagua backend in the order they are appended
when all the tensors within the bucket are marked ready.
This operation is not inplace, which means the bucket weights are first copied to `peer_weight`, and the result of
decentralized averaging will be in `peer_weight`. To copy `peer_weight` back to `self`, call
:func:`decentralized_synchronous_op_copy_back_peer_weight`.
This operation will be executed by the Bagua backend in
the order it is appended, when all the tensors within the bucket are marked ready.
Args:
peer_weight (BaguaTensor): A tensor used for averaging the model with peers; it should be of the same size
as the bucket tensors' total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor.
hierarchical (bool): Enable hierarchical communication, which means the GPUs on the same machine
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
peer_selection_mode (str): Can be "all" or "shift_one" for full precision decentralized operation, "ring" for
low precision decentralized operation. "all" means all workers' weights are averaged in each communication step.
"shift_one" means each worker selects a different peer to average weights with in each communication step.
"ring" means all workers are connected into a ring, and each worker communicates with its neighbors.
communication_interval (int): Number of iterations between two communication steps.
compression: If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is
supported.
weight (BaguaTensor): The current worker's local model, a flattened tensor containing the same data as its model
weights; required for low precision decentralized operation.
left_peer_weight (BaguaTensor): Model replica of the current worker's connected left peer, a flattened tensor containing
the same data as the left peer's model weights; required for low precision decentralized operation.
right_peer_weight (BaguaTensor): Model replica of the current worker's connected right peer, similar to `left_peer_weight`;
required for low precision decentralized operation.
peer_selection_mode (str): Can be "all" or "shift_one". "all" means all workers' weights are averaged
in each communication step. "shift_one" means each worker selects a different peer to average weights with
in each communication step.
Returns:
The bucket itself.
"""

if hierarchical:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
communication_interval=communication_interval,
compression=compression,
weight=weight._bagua_backend_tensor if weight is not None else None,
left_peer_weight=left_peer_weight._bagua_backend_tensor
if left_peer_weight is not None
else None,
right_peer_weight=right_peer_weight._bagua_backend_tensor
if right_peer_weight is not None
else None,
peer_weight=peer_weight._bagua_backend_tensor,
)
else:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
None,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
communication_interval=communication_interval,
peer_weight=peer_weight._bagua_backend_tensor,
)
return self

def decentralized_synchronous_op_copy_back_peer_weight(
self, peer_weight: BaguaTensor, hierarchical: bool = True
):
"""
Copy `peer_weight` back to bucket weights to end a decentralized synchronous operation.
See :func:`append_decentralized_synchronous_op` for more information.
Args:
peer_weight (BaguaTensor): A tensor used for averaging the model with peers; it should be of the same size
as the bucket tensors' total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor.
hierarchical (bool): Enable hierarchical communication, which means the GPUs on the same machine
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high. Must be the same as the `hierarchical` argument in
:func:`append_decentralized_synchronous_op`.
"""
intra_comm = self._bagua_backend.intranode_communicator
inter_comm = self._bagua_backend.internode_communicator

if not hierarchical or (inter_comm is not None):
self.backend_tensor.copy_(peer_weight)

if hierarchical:
broadcast(self.backend_tensor, 0, intra_comm)
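
Taken together, the two methods above form the stateless full precision cycle at the bucket level. A hedged sketch, assuming `bucket` is a `BaguaBucket` already registered with the Bagua backend and that the backend runs the appended op once the bucket's tensors are marked ready:

# Hedged sketch of the full precision cycle on one registered BaguaBucket.
peer_weight = bucket.flattened_tensor().to_bagua_tensor("peer_weight")  # scratch buffer, same size as the bucket

bucket.clear_ops()
bucket.append_decentralized_synchronous_op(
    peer_weight=peer_weight,
    hierarchical=True,
    peer_selection_mode="all",
)

# ... during the step the backend copies the bucket into peer_weight and
# averages it with peers; afterwards the result is written back.
# With hierarchical=True the copy-back also broadcasts within each node.
bucket.decentralized_synchronous_op_copy_back_peer_weight(
    hierarchical=True, peer_weight=peer_weight
)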

def append_low_precision_decentralized_synchronous_op(
self,
weight: BaguaTensor,
left_peer_weight: BaguaTensor,
right_peer_weight: BaguaTensor,
hierarchical: bool = True,
compression: str = "MinMaxUInt8",
) -> BaguaBucket:
"""
Append a low precision decentralized synchronous operation to a bucket. It will compress the difference
of the local model between two successive iterations and exchange it among workers.
The operations will be executed by the Bagua backend in the order they are appended
when all the tensors within the bucket are marked ready.
Args:
weight (BaguaTensor): Model replica of the current worker's local model. It should be of the same size
as the bucket tensors' total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor.
left_peer_weight (BaguaTensor): Model replica of the current worker's left peer. It should be of the same size
as the bucket tensors' total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor,
then copy the initial weights of the current worker's left peer into it.
right_peer_weight (BaguaTensor): Model replica of the current worker's right peer. It should be of the same size
as the bucket tensors' total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor,
then copy the initial weights of the current worker's right peer into it.
hierarchical (bool): Enable hierarchical communication, which means the GPUs on the same machine
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
compression (str): How tensors are compressed for communication. Currently "MinMaxUInt8" is supported.
Returns:
The bucket itself.
"""

if hierarchical:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
weight=weight._bagua_backend_tensor,
left_peer_weight=left_peer_weight._bagua_backend_tensor,
right_peer_weight=right_peer_weight._bagua_backend_tensor,
)
else:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
None,
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
weight=weight._bagua_backend_tensor if weight is not None else None,
left_peer_weight=left_peer_weight._bagua_backend_tensor
if left_peer_weight is not None
else None,
right_peer_weight=right_peer_weight._bagua_backend_tensor
if right_peer_weight is not None
else None,
weight=weight._bagua_backend_tensor,
left_peer_weight=left_peer_weight._bagua_backend_tensor,
right_peer_weight=right_peer_weight._bagua_backend_tensor,
)

return self
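
A hedged sketch of the buffers this op expects, per the docstring above. Initializing the peer replicas with the peers' starting weights needs extra setup-time communication (handled by the algorithm's `_init_states`, not shown in this diff) and is omitted here:

# Hedged sketch: flattened buffers required by the low precision op, one set per bucket.
# The peer replicas must be filled with the peers' initial weights before training.
weight = bucket.flattened_tensor().to_bagua_tensor("weight")
left_peer_weight = bucket.flattened_tensor().to_bagua_tensor("left_peer_weight")
right_peer_weight = bucket.flattened_tensor().to_bagua_tensor("right_peer_weight")

bucket.clear_ops()
bucket.append_low_precision_decentralized_synchronous_op(
    weight=weight,
    left_peer_weight=left_peer_weight,
    right_peer_weight=right_peer_weight,
    hierarchical=True,
    compression="MinMaxUInt8",
)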
10 changes: 4 additions & 6 deletions bagua/torch_api/distributed.py
@@ -198,6 +198,7 @@ def with_bagua( # pytype: disable=module-attr
self, "_ddp_params_and_buffers_to_ignore"
): # for compatibility with PyTorch DDP
self.parameters_to_ignore.extend(self._ddp_params_and_buffers_to_ignore)

self.bagua_train_step_counter = 0
"""
Number of iterations in training mode.
@@ -271,12 +272,8 @@ def record_speed_metrics_event(self, _):
)

# get communicators
self._bagua_inter_node_communicator = (
self._bagua_backend.internode_communicator
)
self._bagua_intra_node_communicator = (
self._bagua_backend.intranode_communicator
)
self._bagua_inter_node_communicator = self._bagua_backend.internode_communicator
self._bagua_intra_node_communicator = self._bagua_backend.intranode_communicator
self._bagua_global_communicator = self._bagua_backend.global_communicator
self.bagua_communication_stream = self._bagua_backend.stream

@@ -388,6 +385,7 @@ def real_post_backward_hook(*unused):
def new_step_factory(optimizer):
def new_step(self, *args, **kwargs):
result = self._bagua_original_step(*args, **kwargs)

optimizer_hook(self)
return result

