feat: make full precision decentralized op stateless #126

Merged · 23 commits · Jul 21, 2021
Changes from 1 commit
33 changes: 30 additions & 3 deletions bagua/torch_api/algorithms/decentralized.py
@@ -3,6 +3,8 @@
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm
from bagua.torch_api.communication import broadcast
from bagua.torch_api.env import get_rank
from typing import List
import torch

@@ -54,22 +56,47 @@ def hook(parameter_name, parameter):
    def init_post_backward_hook(self, bagua_module: BaguaModule):
        def hook():
            # Wait for the decentralized communication scheduled during backward,
            # then run the post-backward communication ops.
            bagua_module._bagua_backend.wait_pending_comm_ops()
            torch.cuda.synchronize()
            bagua_module._bagua_backend.execute_post_backward_comm_ops()
            bagua_module._bagua_backend.wait_pending_post_backward_comm_ops()

            intra_comm = bagua_module._bagua_backend.intranode_communicator

            def copyback_leader_fn(*unused):
                # Rank 0: copy the averaged peer weight back into each bucket and,
                # in hierarchical mode, broadcast the result within the node.
                for bucket in bagua_module.bagua_buckets:
                    bucket.backend_tensor.copy_(bucket._peer_weight)

                    if self.hierarchical:
                        broadcast(bucket.backend_tensor, 0, intra_comm)

            def copyback_worker_fn(*unused):
                # Other ranks: in hierarchical mode receive the broadcast result;
                # otherwise copy back their own peer weight.
                for bucket in bagua_module.bagua_buckets:
                    if self.hierarchical:
                        broadcast(bucket.backend_tensor, 0, intra_comm)
                    else:
                        bucket.backend_tensor.copy_(bucket._peer_weight)

            bagua_module._bagua_backend.schedule_python_op(
                copyback_leader_fn if get_rank() == 0 else copyback_worker_fn
            )
            bagua_module._bagua_backend.wait_pending_comm_ops()

        return hook

    def _init_states(self, bucket: BaguaBucket):
        # Allocate a flattened "peer weight" buffer on the bucket; the
        # decentralized op writes the averaged peer result into it.
        weight_tensor = bucket.flattened_tensor()
        bucket._peer_weight = weight_tensor.to_bagua_tensor("peer_weight")

    def init_operations(
        self,
        bagua_module: BaguaModule,
        bucket: BaguaBucket,
    ):
        self._init_states(bucket)
        torch.cuda.synchronize()
        bucket.clear_ops()
        bucket.append_decentralized_synchronous_op(
            hierarchical=self.hierarchical,
            peer_selection_mode=self.peer_selection_mode,
            communication_interval=self.communication_interval,
            left_peer_weight=bucket._peer_weight,
        )
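
For context, here is a minimal sketch (not part of this PR) of the leader/worker copy-back pattern that the post-backward hook above schedules, rewritten against plain torch.distributed. The helper name copy_back, the intra_node_group handle, and the leader's global rank are illustrative assumptions, not Bagua APIs.

import torch
import torch.distributed as dist

def copy_back(flat_param: torch.Tensor,
              peer_weight: torch.Tensor,
              hierarchical: bool,
              intra_node_group=None,
              leader_global_rank: int = 0) -> None:
    # Illustrative only: mirrors copyback_leader_fn / copyback_worker_fn above.
    if not hierarchical:
        # Every rank copies its own averaged peer result back into the parameters.
        flat_param.copy_(peer_weight)
        return
    if dist.get_rank() == leader_global_rank:
        # Only the leader holds the inter-node peer result; copy it back first.
        flat_param.copy_(peer_weight)
    # Share the leader's parameters with the other ranks in the node group.
    dist.broadcast(flat_param, src=leader_global_rank, group=intra_node_group)

The point of the hierarchical branch is that only one process per node exchanges weights with remote peers, so the remaining local ranks just receive an intra-node broadcast to pick up the result.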


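As a usage sketch, the algorithm in this file is typically attached to a model through with_bagua, which is what invokes init_operations and init_post_backward_hook above. The constructor arguments and the tiny model below are example values only, and the script is assumed to run under the bagua launcher with one process per GPU.

import torch
import bagua.torch_api as bagua
from bagua.torch_api.algorithms.decentralized import DecentralizedAlgorithm

bagua.init_process_group()
torch.cuda.set_device(bagua.get_local_rank())

model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# with_bagua wires the algorithm in: init_operations installs the decentralized
# synchronous op on each bucket, and init_post_backward_hook installs the
# copy-back step shown in this diff.
model = model.with_bagua([optimizer], DecentralizedAlgorithm(hierarchical=True))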