fix: fuse optimizer oom and make it stateless #207

Merged Nov 2, 2021. Changes from all commits (32 commits).

Commits:
eef7b74  init add (wangraying, Sep 17, 2021)
fe66cfe  refactor move (wangraying, Sep 17, 2021)
7902908  patch optimizer (wangraying, Sep 17, 2021)
c3798a4  refactor (wangraying, Sep 18, 2021)
d25dcf7  add support for bagua (wangraying, Sep 18, 2021)
54744f5  fix (wangraying, Sep 18, 2021)
a1f13f1  . (wangraying, Sep 18, 2021)
0802583  add (wangraying, Sep 23, 2021)
11247e1  .. (wangraying, Sep 23, 2021)
72b5744  fix (wangraying, Sep 24, 2021)
8b78658  Merge branch 'master' into fuse-opt (wangraying, Sep 24, 2021)
699846e  .. (wangraying, Sep 24, 2021)
4d95fb8  add (wangraying, Sep 24, 2021)
ff05198  Merge branch 'fuse-opt' of https://github.com/BaguaSys/bagua into fus… (wangraying, Sep 24, 2021)
e426c82  doc (wangraying, Sep 24, 2021)
69a446c  .. (wangraying, Sep 24, 2021)
c3e2e08  Merge branch 'fuse-opt' of https://github.com/BaguaSys/bagua into fus… (wangraying, Sep 24, 2021)
aa08916  fix api (wangraying, Sep 24, 2021)
9ff0873  add checking (wangraying, Sep 24, 2021)
de7df6e  feat(python, core): support mutable bucket tensors (#271) (wangraying, Oct 28, 2021)
b8ddea8  Merge branch 'master' into fuse-opt (wangraying, Oct 28, 2021)
7d9cf0a  update tmp (wangraying, Oct 28, 2021)
32d2981  add tests (wangraying, Oct 28, 2021)
94ee2a1  a runnable version except qadam (wangraying, Oct 29, 2021)
91cdab3  Merge branch 'master' of https://github.com/BaguaSys/bagua (wangraying, Oct 29, 2021)
bbd5c56  Merge branch 'master' into fuse-opt (wangraying, Oct 29, 2021)
9a1fccc  fix (wangraying, Oct 29, 2021)
731f616  stateless (wangraying, Nov 1, 2021)
f72624e  re (wangraying, Nov 1, 2021)
2fd9eb2  doc (wangraying, Nov 1, 2021)
dad451b  . (wangraying, Nov 1, 2021)
5add7f8  . (wangraying, Nov 1, 2021)
10 changes: 7 additions & 3 deletions bagua/torch_api/algorithms/async_model_average.py

@@ -78,15 +78,19 @@ def __init__(
             process_ranks, stream=torch.cuda.Stream(priority=-1)
         )

-    def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
+    def tensors_to_buckets(
+        self, tensors: List[List[BaguaTensor]], do_flatten: bool
+    ) -> List[BaguaBucket]:
+        # TODO: async algorithm conflict with fused optimizer, can only support flattened inplace bucket.
+        assert do_flatten, "async does not support `do_flatten=False` at present."
         if self.step_id < self.warmup_steps:
-            return super().tensors_to_buckets(tensors)
+            return super().tensors_to_buckets(tensors, do_flatten)

         all_tensors = []
         for idx, bucket in enumerate(tensors):
             all_tensors.extend(bucket)

-        bagua_bucket = BaguaBucket(all_tensors, flatten=True, name=str(0))
+        bagua_bucket = BaguaBucket(all_tensors, flatten=do_flatten, name=str(0))

         return [bagua_bucket]
7 changes: 5 additions & 2 deletions bagua/torch_api/algorithms/base.py

@@ -73,7 +73,9 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
         ), "tensor names should be unique"
         return tensors

-    def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
+    def tensors_to_buckets(
+        self, tensors: List[List[BaguaTensor]], do_flatten: bool
+    ) -> List[BaguaBucket]:
         """
         Given the bucketing suggestion from Bagua, return the actual Bagua buckets.
         The default implementation follows the suggestion to do the bucketing.
@@ -82,14 +84,15 @@ def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]
             tensors: Bagua tensors grouped in different
                 lists, representing Bagua's suggestion on how to bucketing the
                 tensors.
+            do_flatten: Whether to flatten the Bagua buckets.

         Returns:
             A list of Bagua buckets.
         """
         bagua_buckets = []
         for idx, bucket in enumerate(tensors):
             bagua_bucket = BaguaBucket(
-                bucket, flatten=True, name=str(idx)
+                bucket, flatten=do_flatten, name=str(idx)
             )  # TODO: check duplicated names
             bagua_buckets.append(bagua_bucket)
         return bagua_buckets
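For orientation (this block is not part of the diff): a minimal sketch of how an algorithm implementation could honor the new `do_flatten` argument. The class name is hypothetical, and the body simply mirrors the default behavior shown in base.py above; the imports are the ones already used in this PR.

    # Hypothetical subclass illustrating the new `do_flatten` argument (not part of this PR).
    from typing import List

    from bagua.torch_api.bucket import BaguaBucket
    from bagua.torch_api.tensor import BaguaTensor
    from bagua.torch_api.algorithms import AlgorithmImpl


    class MyAlgorithmImpl(AlgorithmImpl):
        def tensors_to_buckets(
            self, tensors: List[List[BaguaTensor]], do_flatten: bool
        ) -> List[BaguaBucket]:
            # Follow Bagua's bucketing suggestion, but let the caller decide
            # whether the bucket's backing storage is flattened.
            return [
                BaguaBucket(bucket, flatten=do_flatten, name=str(idx))
                for idx, bucket in enumerate(tensors)
            ]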
18 changes: 4 additions & 14 deletions bagua/torch_api/algorithms/bytegrad.py

@@ -30,24 +30,14 @@ def __init__(
         self.hierarchical = hierarchical
         self.average = average

-    def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
-        """
-        Given the bucketing suggestion from Bagua, return the actual Bagua buckets.
-        The default implementation follows the suggestion to do the bucketing.
-
-        Args:
-            tensors: Bagua tensors grouped in different
-                lists, representing Bagua's suggestion on how to bucketing the
-                tensors.
-
-        Returns:
-            A list of Bagua buckets.
-        """
+    def tensors_to_buckets(
+        self, tensors: List[List[BaguaTensor]], do_flatten: bool
+    ) -> List[BaguaBucket]:
         bagua_buckets = []
         for idx, bucket in enumerate(tensors):
             bagua_bucket = BaguaBucket(
                 bucket,
-                flatten=True,
+                flatten=do_flatten,
                 name=str(idx),
                 alignment=self.process_group.get_global_communicator().nranks(),
             )
6 changes: 4 additions & 2 deletions bagua/torch_api/algorithms/decentralized.py

@@ -48,12 +48,14 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
         ]
         return self.tensors

-    def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
+    def tensors_to_buckets(
+        self, tensors: List[List[BaguaTensor]], do_flatten: bool
+    ) -> List[BaguaBucket]:
         all_tensors = []
         for idx, bucket in enumerate(tensors):
             all_tensors.extend(bucket)

-        bagua_bucket = BaguaBucket(all_tensors, flatten=True, name=str(0))
+        bagua_bucket = BaguaBucket(all_tensors, flatten=do_flatten, name=str(0))

         return [bagua_bucket]
9 changes: 5 additions & 4 deletions bagua/torch_api/algorithms/q_adam.py

@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 from bagua.torch_api.bucket import BaguaBucket
 from bagua.torch_api.tensor import BaguaTensor
-from bagua.torch_api import get_world_size
 from bagua.torch_api.distributed import BaguaModule
 from bagua.torch_api.algorithms import Algorithm, AlgorithmImpl
 from bagua.torch_api.communication import BaguaProcessGroup
@@ -45,7 +44,7 @@ def __init__(
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(QAdamOptimizer, self).__init__(params, defaults)
-
+        # TODO: qadam optimizer maintain `step_id` in its state
         self.step_id = 0
         self.warmup_steps = warmup_steps

@@ -162,12 +161,14 @@ def set_momentum_fn(param, t):
         tensor_groups.sort(key=lambda x: x._q_adam_idx)
         return tensor_groups

-    def tensors_to_buckets(self, tensors: List[List[BaguaTensor]]) -> List[BaguaBucket]:
+    def tensors_to_buckets(
+        self, tensors: List[List[BaguaTensor]], do_flatten: bool
+    ) -> List[BaguaBucket]:
         bagua_buckets = []
         for idx, bucket in enumerate(tensors):
             bagua_bucket = BaguaBucket(
                 bucket,
-                flatten=True,
+                flatten=do_flatten,
                 name=str(idx),
                 alignment=self.process_group.get_global_communicator().nranks(),
             )
25 changes: 5 additions & 20 deletions bagua/torch_api/bucket.py

@@ -8,7 +8,7 @@
 import torch

 from bagua.torch_api.tensor import BaguaTensor
-from bagua.torch_api.utils import check_contiguous
+from bagua.torch_api.utils import check_contiguous, get_flattened_tensor
 from bagua.torch_api.communication import (
     BaguaProcessGroup,
     _bagua_backend_comm,
@@ -87,25 +87,10 @@ def flattened_tensor(self) -> torch.Tensor:
         :attr:`self` tensors and padding tensor (if exists).
         """

-        all_registered_tensors = [
+        all_effective_tensors = [
             tensor.bagua_getter_closure() for tensor in self._all_tensors
         ]
-        total_size = 0
-        for tensor in all_registered_tensors:
-            total_size += tensor.numel()
-
-        flatten_tensor = torch.zeros(
-            total_size,
-            dtype=all_registered_tensors[0].dtype,
-            device=all_registered_tensors[0].device,
-        )
-
-        offset = 0
-        for tensor in all_registered_tensors:
-            # copy data
-            flatten_tensor[offset : offset + tensor.numel()] = tensor.reshape(-1)
-            offset += tensor.numel()
-        return flatten_tensor
+        return get_flattened_tensor(all_effective_tensors)

     def _flatten_(self):
         """
@@ -372,7 +357,7 @@ def clear_ops(self) -> BaguaBucket:

     def bytes(self) -> int:
         """Returns the total number of bytes occupied by the bucket."""
-        registered_tensors = [tensor.bagua_getter_closure() for tensor in self.tensors]
+        effective_tensors = [tensor.bagua_getter_closure() for tensor in self.tensors]
         return sum(
-            tensor.numel() * tensor.element_size() for tensor in registered_tensors
+            tensor.numel() * tensor.element_size() for tensor in effective_tensors
         )
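The inline flattening logic removed above is factored out into `get_flattened_tensor` in `bagua.torch_api.utils`. Its implementation is not part of this excerpt; judging from the deleted code, a rough sketch of what it presumably does (signature assumed) looks like this:

    # Sketch of get_flattened_tensor, reconstructed from the code removed above.
    from typing import List

    import torch


    def get_flattened_tensor(tensors: List[torch.Tensor]) -> torch.Tensor:
        # Concatenate the given tensors into one contiguous flat tensor on the
        # same device and with the same dtype as the inputs.
        if len(tensors) == 0:
            return torch.tensor([])

        total_size = sum(tensor.numel() for tensor in tensors)
        flatten_tensor = torch.zeros(
            total_size, dtype=tensors[0].dtype, device=tensors[0].device
        )

        offset = 0
        for tensor in tensors:
            # copy data
            flatten_tensor[offset : offset + tensor.numel()] = tensor.reshape(-1)
            offset += tensor.numel()
        return flatten_tensor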
2 changes: 1 addition & 1 deletion bagua/torch_api/contrib/__init__.py

@@ -1,4 +1,4 @@
-from .fused_optimizer import FusedOptimizer  # noqa: F401
+from .fuse.optimizer import fuse_optimizer  # noqa: F401
 from .load_balancing_data_loader import (  # noqa: F401
     LoadBalancingDistributedSampler,
     LoadBalancingDistributedBatchSampler,
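As the import change above suggests, the stateful `FusedOptimizer` wrapper class is replaced by a `fuse_optimizer` function. The sketch below shows one plausible way to wire it up, assuming it wraps a vanilla torch optimizer and is then passed to `with_bagua` as usual; the exact keyword arguments of `fuse_optimizer` and its interaction with bucket flattening are not shown in this excerpt, so treat them as assumptions rather than the confirmed API:

    # Usage sketch (assumed API, not confirmed by this diff excerpt).
    # Run under a distributed launcher, e.g. python -m bagua.distributed.launch ...
    import torch
    import bagua.torch_api as bagua
    from bagua.torch_api.contrib import fuse_optimizer
    from bagua.torch_api.algorithms import gradient_allreduce

    bagua.init_process_group()
    torch.cuda.set_device(bagua.get_local_rank())

    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Fuse the optimizer so parameter updates can operate on contiguous
    # (flattened) storage instead of launching one kernel per parameter.
    optimizer = fuse_optimizer(optimizer)

    model = model.with_bagua(
        [optimizer], gradient_allreduce.GradientAllReduceAlgorithm()
    )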