feat(python): separate algorithm declaration and implementation #246

Merged · 51 commits · Oct 19, 2021

Changes from 2 commits
a36a6ea
refactor(python): support distributed algorithm instance reuse for mu…
Oct 6, 2021
9e71d33
Update bagua/torch_api/algorithms/async_model_average.py
Tengxu-Sun Oct 6, 2021
b29d911
Update bagua/torch_api/algorithms/bytegrad.py
Tengxu-Sun Oct 6, 2021
e5620ca
Update bagua/torch_api/algorithms/decentralized.py
Tengxu-Sun Oct 6, 2021
8334548
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
1f27cae
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
20ed8ca
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
37c5774
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
2c30d74
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
ea3dea6
Update bagua/torch_api/algorithms/decentralized.py
Tengxu-Sun Oct 6, 2021
5996c55
Update bagua/torch_api/algorithms/decentralized.py
Tengxu-Sun Oct 6, 2021
55388f5
Update bagua/torch_api/algorithms/q_adam.py
Tengxu-Sun Oct 6, 2021
b0ba5b2
Update bagua/torch_api/algorithms/gradient_allreduce.py
Tengxu-Sun Oct 6, 2021
ac230df
Update bagua/torch_api/algorithms_implementation/__init__.py
Tengxu-Sun Oct 6, 2021
87188b0
Update bagua/torch_api/algorithms_implementation/q_adam_implementatio…
Tengxu-Sun Oct 6, 2021
fdad05b
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
2650954
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
c661c32
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
4c1659d
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
9232149
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
620680d
Update tests/torch_api/test_multi_models.py
Tengxu-Sun Oct 6, 2021
6c00c7c
fix error
Oct 6, 2021
7e17a83
fix error
Oct 6, 2021
d952fd4
fix error
Oct 6, 2021
03125f9
fix error
Oct 6, 2021
d91aa57
fix error
Oct 6, 2021
1d7c9f5
fix error
Oct 6, 2021
69eba9e
fix error
Oct 6, 2021
0f35ede
fix error
Oct 6, 2021
fa33bed
fix error
Oct 6, 2021
e4c743c
fix error
Oct 7, 2021
f4c6f84
Update bagua/bagua_define.py
Tengxu-Sun Oct 7, 2021
778fdcb
Update bagua/bagua_define.py
Tengxu-Sun Oct 7, 2021
8930ebf
fix error of async
Oct 8, 2021
3eefb40
retrigger checks for lock error
Oct 8, 2021
ccfe687
merge the algorithm definition and implementation together
Oct 13, 2021
f8a12f6
Merge branch 'master' into algorithm_check
Oct 13, 2021
d2b3512
fix q_adam refactor error
Oct 14, 2021
359b82d
change docs
Oct 14, 2021
700494e
merge master and resolve conflict file
Oct 17, 2021
3d21af4
fix error
Oct 17, 2021
e3a6156
retrigger checks
Oct 17, 2021
431bb3d
Update bagua/torch_api/algorithms/gradient_allreduce.py
NOBLES5E Oct 18, 2021
5d0ccdc
Update async_model_average.py
NOBLES5E Oct 18, 2021
79eaaad
delete underscore
Oct 18, 2021
cff7466
Update async_model_average.py
NOBLES5E Oct 18, 2021
3887036
Merge branch 'algorithm_check' of https://github.com/BaguaSys/bagua i…
Oct 18, 2021
ee60440
refactor the inheritance relationship
Oct 19, 2021
a852f08
rename the instance
Oct 19, 2021
de0c7e0
fix error
Oct 19, 2021
83c4f6f
fix docs
Oct 19, 2021
1 change: 0 additions & 1 deletion bagua/torch_api/algorithms/__init__.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3

from .base import Algorithm # noqa: F401
from . import bytegrad, decentralized, gradient_allreduce # noqa: F401
from . import q_adam, async_model_average # noqa: F401
198 changes: 10 additions & 188 deletions bagua/torch_api/algorithms/async_model_average.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python3
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm
from typing import List
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.env import get_rank
@@ -10,17 +9,14 @@
import time
import torch
import logging
import concurrent
from bagua.torch_api.algorithms_implementation.async_model_average_implementation import (
AsyncModelAverageAlgorithm_Implementation,
)

__all__ = ["AsyncModelAverageAlgorithm"]


class _AsyncInternalState(IntEnum):
RESUME = 0
ABORT = 1


class AsyncModelAverageAlgorithm(Algorithm):
class AsyncModelAverageAlgorithm:
def __init__(
self,
peer_selection_mode: str = "all",
@@ -49,185 +45,11 @@ def __init__(

self.peer_selection_mode = peer_selection_mode
self.sync_interval_ms = sync_interval_ms
self.step_id = 0
self.warmup_steps = warmup_steps

self.cuda_event = torch.cuda.Event()

self.abort_event = threading.Event()
self.dummy_tensor = torch.Tensor([0]).byte()
self.main_group = torch.distributed.new_group(backend="gloo")
self.thread_group = torch.distributed.new_group(backend="gloo")

self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.scheduled = False

def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
if self.step_id < self.warmup_steps:
return super().tensors_to_buckets(tensors)

all_tensors = []
for idx, bucket in enumerate(tensors):
all_tensors.extend(bucket)

bagua_bucket = BaguaBucket(all_tensors, flatten=True, name=str(0))

return [bagua_bucket]

def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
parameters = bagua_module.bagua_build_params()
tensors = []
for name, param in parameters.__reversed__():
if self.step_id < self.warmup_steps:
grad = param.bagua_ensure_grad().ensure_bagua_tensor(
name, bagua_module.bagua_module_name
)
param._bagua_grad = grad
tensors.append(grad)
else:
p = param.ensure_bagua_tensor(name, bagua_module.bagua_module_name)
tensors.append(p)

return tensors

def init_forward_pre_hook(self, bagua_module: BaguaModule):
def hook(input):
if (
self.step_id > self.warmup_steps
and self.sync_interval_ms > 0 # noqa: W503
):
self._lock_model(bagua_module)

if not hasattr(self, "future"):
self.future = self.executor.submit(
self._run_async_loop, bagua_module
)
self.scheduled = True
logging.debug(
"Process {} async communication started.".format(get_rank())
)

return hook

def init_backward_hook(self, bagua_module: BaguaModule):
def hook(parameter_name, parameter):
if self.step_id <= self.warmup_steps:
parameter._bagua_grad.bagua_mark_communication_ready()

return hook

def init_post_backward_hook(self, bagua_module: BaguaModule):
def hook():
if self.step_id <= self.warmup_steps:
bagua_module._bagua_backend.wait_pending_comm_ops()
else:
self._unlock_model(bagua_module)

return hook

def need_reset(self):
self.step_id += 1

if self.warmup_steps > 0 and self.step_id == self.warmup_steps + 1:
logging.info(f"Async model average starts from step {self.step_id}")
return True
else:
return False

def init_operations(
self,
bagua_module: BaguaModule,
bucket: BaguaBucket,
):
bagua_module._bagua_backend.wait_pending_comm_ops()
bucket.clear_ops()

if self.step_id < self.warmup_steps:
bucket.append_centralized_synchronous_op(
hierarchical=False,
average=True,
)
else:
async_op = bucket.append_asynchronous_model_average_op(
peer_selection_mode=self.peer_selection_mode,
)
bucket._async_op = async_op

def _lock_model(self, bagua_module: BaguaModule):
torch.cuda.current_stream().record_event(self.cuda_event)
self.cuda_event.synchronize()

for bucket in bagua_module.bagua_buckets:
bucket._async_op.lock_weight()

def _unlock_model(self, bagua_module: BaguaModule):
torch.cuda.current_stream().record_event(self.cuda_event)
self.cuda_event.synchronize()

for bucket in bagua_module.bagua_buckets:
bucket._async_op.unlock_weight()

def _negotiate(self):
if self.abort_event.is_set():
self.dummy_tensor[0] = _AsyncInternalState.ABORT
else:
self.dummy_tensor[0] = _AsyncInternalState.RESUME

torch.distributed.broadcast(self.dummy_tensor, src=0, group=self.thread_group)
return self.dummy_tensor.item()

def _run_async_loop(self, bagua_module: BaguaModule):
comm_step = 0
while True:
state = self._negotiate()

if state == _AsyncInternalState.ABORT:
break

start_time = time.time()
for bucket in bagua_module.bagua_buckets:
for tensor in bucket.tensors:
tensor.bagua_mark_communication_ready_without_synchronization()

bagua_module._bagua_backend.wait_pending_comm_ops()
duration = (time.time() - start_time) * 1000

logging.debug(
"Process {} async communication cost {}ms, comm_step={}".format(
get_rank(), duration, comm_step
)
)
comm_step += 1
time.sleep(self.sync_interval_ms / 1000)

def abort(self, bagua_module: BaguaModule):
"""
Stop background asynchronous communications. Should be called after training.

Args:
bagua_module: A PyTorch module initialized by
:meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua` method.
"""

if self.scheduled:
torch.distributed.barrier(group=self.main_group)
self.abort_event.set()
self.future.result() # pytype: disable=attribute-error
self.scheduled = False
logging.debug("Process {} async communication aborted.".format(get_rank()))

def resume(self, bagua_module: BaguaModule):
"""
Resume aborted background asynchronous communications (see :meth:`abort`). Should be called before training.

Args:
bagua_module: A PyTorch module initialized by
:meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua` method.
"""

if not self.scheduled and hasattr(self, "future"):
torch.distributed.barrier(group=self.main_group)
self.abort_event.clear()
self.future = self.executor.submit(self._run_async_loop, bagua_module)
self.scheduled = True
logging.debug("Process {} async communication resumed.".format(get_rank()))
def reify(self) -> AsyncModelAverageAlgorithm_Implementation:
return AsyncModelAverageAlgorithm_Implementation(
peer_selection_mode=self.peer_selection_mode,
sync_interval_ms=self.sync_interval_ms,
warmup_steps=self.warmup_steps,
)
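
With the state and hooks moved out, AsyncModelAverageAlgorithm is now a plain declaration whose reify() builds an AsyncModelAverageAlgorithm_Implementation holding everything deleted above (the CUDA events, the background thread pool, and the warmup/step bookkeeping). A hedged usage sketch follows; the with_bagua([optimizer], algorithm) call shape is taken from Bagua's existing docs, and the hyperparameter values are illustrative:

# Hedged usage sketch: with_bagua's call shape follows Bagua's existing docs;
# reify() is what the framework is expected to invoke internally.
import torch
from bagua.torch_api.algorithms.async_model_average import AsyncModelAverageAlgorithm

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The user still constructs the lightweight declaration object ...
algorithm = AsyncModelAverageAlgorithm(
    peer_selection_mode="all",   # default shown in the diff above
    sync_interval_ms=500,        # illustrative value
    warmup_steps=100,            # illustrative value
)

# ... and hands it to with_bagua; Bagua turns it into the implementation that
# owns the background communication thread, events, and hooks.
model = model.with_bagua([optimizer], algorithm)

# Roughly what happens behind the scenes (sketch):
impl = algorithm.reify()  # -> AsyncModelAverageAlgorithm_Implementation
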
39 changes: 5 additions & 34 deletions bagua/torch_api/algorithms/bytegrad.py
@@ -3,12 +3,12 @@
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.algorithms import Algorithm
from bagua.torch_api import get_world_size
from typing import List
from bagua.torch_api.algorithms_implementation.bytegrad_implementation import ByteGradAlgorithm_Implementation


class ByteGradAlgorithm(Algorithm):
class ByteGradAlgorithm:
def __init__(self, average: bool = True):
"""
Create an instance of the
@@ -21,36 +21,7 @@ def __init__(self, average: bool = True):
"""
self.average = average

def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
"""
Given the bucketing suggestion from Bagua, return the actual Bagua buckets.
The default implementation follows the suggestion to do the bucketing.

Args:
tensors: Bagua tensors grouped in different
lists, representing Bagua's suggestion on how to bucketing the
tensors.

Returns:
A list of Bagua buckets.
"""
bagua_buckets = []
for idx, bucket in enumerate(tensors):
bagua_bucket = BaguaBucket(
bucket, flatten=True, name=str(idx), alignment=get_world_size()
)
bagua_buckets.append(bagua_bucket)
return bagua_buckets

def init_operations(
self,
bagua_module: BaguaModule,
bucket: BaguaBucket,
):
bucket.clear_ops()
bucket.append_centralized_synchronous_op(
hierarchical=True,
def reify(self) -> ByteGradAlgorithm_Implementation:
return ByteGradAlgorithm_Implementation(
average=self.average,
scattergather=True,
compression="MinMaxUInt8",
)
)
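
The tensors_to_buckets and init_operations bodies deleted above presumably move into ByteGradAlgorithm_Implementation; the declaration now only forwards average (plus the fixed scattergather=True and compression="MinMaxUInt8" arguments) through reify(). Below is a rough reconstruction of what that relocated implementation could look like, assembled from the deleted hunk; it is an assumption, not the contents of the actual implementation file:

# Sketch reconstructed from the deleted hunk above; the real class lives in
# bagua/torch_api/algorithms_implementation/bytegrad_implementation.py and may
# differ (e.g. it likely inherits from a shared implementation base class).
from typing import List

from bagua.torch_api import get_world_size
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.distributed import BaguaModule
from bagua.torch_api.tensor import BaguaTensor


class ByteGradAlgorithm_Implementation:
    def __init__(self, average: bool, scattergather: bool, compression: str):
        self.average = average
        self.scattergather = scattergather
        self.compression = compression

    def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
        # Same bucketing as the deleted declaration-side code: follow Bagua's
        # suggestion, flattening each bucket and aligning it to the world size.
        return [
            BaguaBucket(bucket, flatten=True, name=str(idx), alignment=get_world_size())
            for idx, bucket in enumerate(tensors)
        ]

    def init_operations(self, bagua_module: BaguaModule, bucket: BaguaBucket):
        # Register a hierarchical centralized op with ByteGrad's fixed
        # scatter-gather + MinMaxUInt8 compression settings.
        bucket.clear_ops()
        bucket.append_centralized_synchronous_op(
            hierarchical=True,
            average=self.average,
            scattergather=self.scattergather,
            compression=self.compression,
        )
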