feat: make full precision decentralized op stateless #126

Merged · 23 commits · Jul 21, 2021
Changes from 14 commits
43 changes: 26 additions & 17 deletions bagua/torch_api/algorithms/decentralized.py
@@ -25,6 +25,7 @@ def __init__(
weights are averaged in each communication step. "shift_one" means each worker
selects a different peer to do weights average in each communication step.
communication_interval (int): Number of iterations between two communication steps.

"""
self.hierarchical = hierarchical
self.peer_selection_mode = peer_selection_mode
@@ -40,8 +41,9 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:

def init_forward_pre_hook(self, bagua_module: BaguaModule):
def hook(input):
for tensor in self.tensors:
tensor.bagua_mark_communication_ready()
if bagua_module.bagua_train_step_counter % self.communication_interval == 0:
for tensor in self.tensors:
tensor.bagua_mark_communication_ready()

return hook
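
As an aside for reviewers, a minimal stand-alone sketch of the gating this hook introduces (the interval value is hypothetical; no Bagua setup needed):

communication_interval = 4  # hypothetical value
for bagua_train_step_counter in range(9):
    # Communication only happens on steps where the counter is a multiple of the interval.
    if bagua_train_step_counter % communication_interval == 0:
        print(bagua_train_step_counter, "-> mark tensors communication-ready")
    else:
        print(bagua_train_step_counter, "-> skip communication this step")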

@@ -53,23 +55,31 @@ def hook(parameter_name, parameter):

def init_post_backward_hook(self, bagua_module: BaguaModule):
def hook():
bagua_module._bagua_backend.wait_pending_comm_ops()
torch.cuda.synchronize()
bagua_module._bagua_backend.execute_post_backward_comm_ops()
bagua_module._bagua_backend.wait_pending_post_backward_comm_ops()
if bagua_module.bagua_train_step_counter % self.communication_interval == 0:
bagua_module._bagua_backend.wait_pending_comm_ops()
for bucket in bagua_module.bagua_buckets:
bucket.decentralized_synchronous_op_copy_back_peer_weight(
hierarchical=self.hierarchical, peer_weight=bucket._peer_weight
)

return hook

def _init_states(self, bucket: BaguaBucket):
weight_tensor = bucket.flattened_tensor()
bucket._peer_weight = weight_tensor.to_bagua_tensor("peer_weight")

def init_operations(
self,
bagua_module: BaguaModule,
bucket: BaguaBucket,
):
self._init_states(bucket)
torch.cuda.synchronize()
bucket.clear_ops()
bucket.append_decentralized_synchronous_op(
peer_weight=bucket._peer_weight,
hierarchical=self.hierarchical,
peer_selection_mode=self.peer_selection_mode,
communication_interval=self.communication_interval,
)
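
For context, a usage-level sketch of how this algorithm is typically attached to a model after this change (the model and optimizer are placeholders, and the init calls assume the standard Bagua setup under a distributed launcher):

import torch
import bagua.torch_api as bagua
from bagua.torch_api.algorithms.decentralized import DecentralizedAlgorithm

torch.cuda.set_device(bagua.get_local_rank())
bagua.init_process_group()

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The op itself is now stateless; the interval below is enforced by the
# Python-side hooks shown above (train_step_counter % communication_interval).
algorithm = DecentralizedAlgorithm(
    hierarchical=True,
    peer_selection_mode="all",
    communication_interval=2,
)
model = model.with_bagua([optimizer], algorithm)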


@@ -123,12 +133,13 @@ def hook():

def init_post_optimizer_step_hook(self, bagua_module: BaguaModule):
def hook(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
if param.is_bagua_tensor():
param.bagua_mark_communication_ready()
if bagua_module.bagua_train_step_counter % self.communication_interval == 0:
for group in optimizer.param_groups:
for param in group["params"]:
if param.is_bagua_tensor():
param.bagua_mark_communication_ready()

bagua_module._bagua_backend.wait_pending_comm_ops()
bagua_module._bagua_backend.wait_pending_comm_ops()

return hook

@@ -153,12 +164,10 @@ def init_operations(
self._init_states(bucket)
torch.cuda.synchronize()
bucket.clear_ops()
bucket.append_decentralized_synchronous_op(
hierarchical=self.hierarchical,
peer_selection_mode="ring",
communication_interval=self.communication_interval,
compression="MinMaxUInt8",
bucket.append_low_precision_decentralized_synchronous_op(
weight=bucket._weight,
left_peer_weight=bucket._left_peer_weight,
right_peer_weight=bucket._right_peer_weight,
hierarchical=self.hierarchical,
compression="MinMaxUInt8",
)
129 changes: 92 additions & 37 deletions bagua/torch_api/bucket.py
@@ -9,6 +9,7 @@

from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.utils import check_contiguous
from bagua.torch_api.communication import broadcast


class BaguaBucket:
@@ -203,71 +204,125 @@ def append_centralized_synchronous_op(

def append_decentralized_synchronous_op(
self,
peer_weight: BaguaTensor,
hierarchical: bool = True,
peer_selection_mode: str = "all",
communication_interval: int = 1,
compression: Optional[str] = None,
weight: Optional[BaguaTensor] = None,
left_peer_weight: Optional[BaguaTensor] = None,
right_peer_weight: Optional[BaguaTensor] = None,
) -> BaguaBucket:
"""
Append a decentralized synchronous operation to a bucket. It will do gossip-style model averaging among workers.

The operations will be executed by the Bagua backend in the order they are appended
when all the tensors within the bucket are marked ready.
This operation is not in place, which means the bucket weights are first copied to `peer_weight`, and the result of
decentralized averaging will be in `peer_weight`. To copy `peer_weight` back to `self`, call
:func:`decentralized_synchronous_op_copy_back_peer_weight`.

This operation will be executed by the Bagua backend in the order it is appended,
when all the tensors within the bucket are marked ready.

Args:
peer_weight (BaguaTensor): A tensor used for averaging the model with peers; it should be of the same size
as the bucket weights, i.e. the `self` tensors plus the padding tensor (if it exists).
hierarchical (bool): Enable hierarchical communication, which means the GPUs on the same machine
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
peer_selection_mode (str): Can be "all" or "shift_one" for the full precision decentralized operation, or "ring" for
the low precision decentralized operation. "all" means all workers' weights are averaged in each communication step.
"shift_one" means each worker selects a different peer to average weights with in each communication step.
"ring" means all workers are connected into a ring, and each worker communicates with its neighbors.
communication_interval (int): Number of iterations between two communication steps.
compression (Optional[str]): If not ``None``, the tensors will be compressed for communication. Currently "MinMaxUInt8" is
supported.
weight (BaguaTensor): Local model of the current worker, a flattened tensor containing the same data as the current
worker's local model weights; required for the low precision decentralized operation.
left_peer_weight (BaguaTensor): Model replica of the current worker's connected left peer, a flattened tensor containing
the same data as the left peer's model weights; required for the low precision decentralized operation.
right_peer_weight (BaguaTensor): Model replica of the current worker's connected right peer, similar to `left_peer_weight`;
required for the low precision decentralized operation.
peer_selection_mode (str): Can be "all" or "shift_one". "all" means all workers' weights are averaged
in each communication step. "shift_one" means each worker selects a different peer to average weights with
in each communication step.
Returns:
The bucket itself.
"""

if hierarchical:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
communication_interval=communication_interval,
compression=compression,
weight=weight._bagua_backend_tensor if weight is not None else None,
left_peer_weight=left_peer_weight._bagua_backend_tensor
if left_peer_weight is not None
else None,
right_peer_weight=right_peer_weight._bagua_backend_tensor
if right_peer_weight is not None
else None,
peer_weight=peer_weight._bagua_backend_tensor,
)
else:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
None,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
communication_interval=communication_interval,
peer_weight=peer_weight._bagua_backend_tensor,
)
return self

def decentralized_synchronous_op_copy_back_peer_weight(
self, peer_weight: BaguaTensor, hierarchical: bool = True
):
"""
Copy `peer_weight` back to bucket weights to end a decentralized synchronous operation.
See :func:`append_decentralized_synchronous_op` for more information.

Args:
peer_weight (BaguaTensor): A tensor used for averaging the model with peers; it should be of the same size
as the bucket weights, i.e. the `self` tensors plus the padding tensor (if it exists).
hierarchical (bool): Enable hierarchical communication, which means the GPUs on the same machine
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
"""
intra_comm = self._bagua_backend.intranode_communicator
inter_comm = self._bagua_backend.internode_communicator

if not hierarchical or (inter_comm is not None):
self.backend_tensor.copy_(peer_weight)

if hierarchical:
broadcast(self.backend_tensor, 0, intra_comm)
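
To make the hierarchical branch above concrete: only the worker that holds an inter-node communicator copies the averaged `peer_weight` back, and the result is then broadcast to the other GPUs on the same node. A toy, single-process mirror of that control flow with plain tensors (illustrative names, no Bagua backend involved):

import torch

def copy_back_peer_weight(bucket_weight, peer_weight, hierarchical,
                          has_internode_comm, intranode_broadcast):
    # Mirrors the branch structure of the method above.
    if not hierarchical or has_internode_comm:
        bucket_weight.copy_(peer_weight)
    if hierarchical:
        intranode_broadcast(bucket_weight)  # broadcast from intra-node rank 0

bucket_weight = torch.zeros(4)
peer_weight = torch.ones(4)
copy_back_peer_weight(bucket_weight, peer_weight, hierarchical=False,
                      has_internode_comm=True, intranode_broadcast=lambda t: None)
print(bucket_weight)  # tensor([1., 1., 1., 1.])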

def append_low_precision_decentralized_synchronous_op(
self,
weight: BaguaTensor,
left_peer_weight: BaguaTensor,
right_peer_weight: BaguaTensor,
hierarchical: bool = True,
compression: str = "MinMaxUInt8",
) -> BaguaBucket:
"""
Append a low precision decentralized synchronous operation to a bucket. It will compress the difference
of the local model between two successive iterations and exchange it among workers.

The operations will be executed by the Bagua backend in the order they are appended
when all the tensors within the bucket are marked ready.

Args:
weight (BaguaTensor): Model replica of the current worker's local model; it should be initialized with the same
data as the current worker's model weights.
left_peer_weight (BaguaTensor): Model replica of the current worker's connected left peer; it should be initialized
with the same data as the left peer's model weights.
right_peer_weight (BaguaTensor): Model replica of the current worker's connected right peer; it should be initialized
with the same data as the right peer's model weights.
hierarchical (bool): Enable hierarchical communication, which means the GPUs on the same machine
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
compression (str): The way tensors are compressed for communication. Currently "MinMaxUInt8" is supported.
Returns:
The bucket itself.
"""

if hierarchical:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
self._bagua_backend.internode_communicator,
self._bagua_backend.intranode_communicator,
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
weight=weight._bagua_backend_tensor,
left_peer_weight=left_peer_weight._bagua_backend_tensor,
right_peer_weight=right_peer_weight._bagua_backend_tensor,
)
else:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
None,
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
weight=weight._bagua_backend_tensor if weight is not None else None,
left_peer_weight=left_peer_weight._bagua_backend_tensor
if left_peer_weight is not None
else None,
right_peer_weight=right_peer_weight._bagua_backend_tensor
if right_peer_weight is not None
else None,
weight=weight._bagua_backend_tensor,
left_peer_weight=left_peer_weight._bagua_backend_tensor,
right_peer_weight=right_peer_weight._bagua_backend_tensor,
)

return self
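
The three weight tensors above are prepared by the algorithm before appending the op. A plausible sketch of that setup, following the `_init_states` pattern shown earlier for the full precision op (the attribute names come from the call sites in this diff; the rest is an assumption):

def _init_states(self, bucket):
    # Each tensor gets its own flattened copy of the bucket's initial weights.
    bucket._weight = bucket.flattened_tensor().to_bagua_tensor("weight")
    bucket._left_peer_weight = bucket.flattened_tensor().to_bagua_tensor("left_peer_weight")
    bucket._right_peer_weight = bucket.flattened_tensor().to_bagua_tensor("right_peer_weight")
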
12 changes: 5 additions & 7 deletions bagua/torch_api/distributed.py
@@ -199,7 +199,8 @@ def with_bagua( # pytype: disable=module-attr
self, "_ddp_params_and_buffers_to_ignore"
): # for compatibility with PyTorch DDP
self.parameters_to_ignore.extend(self._ddp_params_and_buffers_to_ignore)
self.bagua_train_step_counter = 0

self.bagua_train_step_counter = -1
"""
Number of iterations in training mode.
"""
@@ -272,12 +273,8 @@ def record_speed_metrics_event(self, _):
)

# get communicators
self._bagua_inter_node_communicator = (
self._bagua_backend.internode_communicator
)
self._bagua_intra_node_communicator = (
self._bagua_backend.intranode_communicator
)
self._bagua_inter_node_communicator = self._bagua_backend.internode_communicator
self._bagua_intra_node_communicator = self._bagua_backend.intranode_communicator
self._bagua_global_communicator = self._bagua_backend.global_communicator
self.bagua_communication_stream = self._bagua_backend.stream

@@ -389,6 +386,7 @@ def real_post_backward_hook(*unused):
def new_step_factory(optimizer):
def new_step(self, *args, **kwargs):
result = self._bagua_original_step(*args, **kwargs)

optimizer_hook(self)
return result
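
For readers less familiar with this part of `with_bagua`: `new_step_factory` swaps in a wrapped `step` so the algorithm's post-optimizer-step hook runs right after the original step. A stand-alone sketch of that pattern (illustrative names, no Bagua dependency):

import types
import torch

def wrap_step(optimizer, post_step_hook):
    original_step = optimizer.step  # bound method of this optimizer instance

    def new_step(self, *args, **kwargs):
        result = original_step(*args, **kwargs)
        post_step_hook(self)  # e.g. mark parameters communication-ready
        return result

    optimizer.step = types.MethodType(new_step, optimizer)
    return optimizer

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(2))], lr=0.1)
wrap_step(optimizer, lambda opt: print("post-step hook ran"))
optimizer.step()  # prints "post-step hook ran"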
