feat(python, core): support mutable bucket tensors #271

Merged
38 commits merged on Oct 28, 2021
Changes from 7 commits
Commits (38)
0eb6994
reset tensorpy
wangraying Oct 8, 2021
83da89c
refactor
wangraying Oct 8, 2021
2c53d8f
make it ok for allreduce
wangraying Oct 9, 2021
9351840
update
wangraying Oct 9, 2021
4957026
add tests
wangraying Oct 9, 2021
b771b48
use getter closure and setter closure
wangraying Oct 11, 2021
cf73707
update
wangraying Oct 11, 2021
0222d93
tmp save
wangraying Oct 12, 2021
a060944
Merge branch 'master' into bucket-tensor
wangraying Oct 19, 2021
b007294
.
wangraying Oct 19, 2021
e94deca
f
wangraying Oct 20, 2021
fda2391
fix and add
wangraying Oct 20, 2021
de605fc
fix
wangraying Oct 20, 2021
b32eb61
support qadam
wangraying Oct 20, 2021
dde2d56
.
wangraying Oct 20, 2021
6a0a171
Merge branch 'master' into bucket-tensor
wangraying Oct 21, 2021
2f9acbd
close https://github.com/BaguaSys/bagua/issues/287
wangraying Oct 21, 2021
7f9c081
add doc
wangraying Oct 21, 2021
7e4e5da
rename
wangraying Oct 21, 2021
02f1699
add
wangraying Oct 21, 2021
419279c
remove fallback to python
wangraying Oct 22, 2021
76c54ad
add sanity check
wangraying Oct 22, 2021
51d2450
remove unwrap
wangraying Oct 22, 2021
6355b09
.
wangraying Oct 22, 2021
31db282
update doc
wangraying Oct 27, 2021
693de1c
update
wangraying Oct 27, 2021
440f832
.
wangraying Oct 27, 2021
f8c61c8
.
wangraying Oct 27, 2021
b45c201
.
wangraying Oct 27, 2021
999bd5b
.
wangraying Oct 27, 2021
3a496dc
Update tensor.py
NOBLES5E Oct 28, 2021
785aa73
Update tensor.py
NOBLES5E Oct 28, 2021
1aae792
Update tensor.py
NOBLES5E Oct 28, 2021
c03432c
Update tensor.py
NOBLES5E Oct 28, 2021
803b64f
update doc
wangraying Oct 28, 2021
c264345
Merge branch 'bucket-tensor' of https://github.com/BaguaSys/bagua int…
wangraying Oct 28, 2021
c634b0a
.
wangraying Oct 28, 2021
b98f94d
.
wangraying Oct 28, 2021
12 changes: 7 additions & 5 deletions bagua/torch_api/algorithms/async_model_average.py
@@ -79,11 +79,13 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
tensors = []
for name, param in parameters.__reversed__():
if self.step_id < self.warmup_steps:
grad = param.bagua_ensure_grad().ensure_bagua_tensor(
name, bagua_module.bagua_module_name
param = param.bagua_ensure_grad().ensure_bagua_tensor(
name,
bagua_module.bagua_module_name,
getter_closure=lambda param: param.grad,
setter_closure=lambda param, t: setattr(param, "grad", t),
)
param._bagua_grad = grad
tensors.append(grad)
tensors.append(param)
else:
p = param.ensure_bagua_tensor(name, bagua_module.bagua_module_name)
tensors.append(p)
@@ -112,7 +114,7 @@ def hook(input):
def init_backward_hook(self, bagua_module: BaguaModule):
def hook(parameter_name, parameter):
if self.step_id <= self.warmup_steps:
parameter._bagua_grad.bagua_mark_communication_ready()
parameter.bagua_mark_communication_ready()

return hook

16 changes: 10 additions & 6 deletions bagua/torch_api/algorithms/base.py
@@ -37,11 +37,14 @@ def init_tensors(self, bagua_module: BaguaModule) -> List[BaguaTensor]:
parameters = bagua_module.bagua_build_params()
tensors = []
for name, param in parameters.__reversed__():
grad = param.bagua_ensure_grad().ensure_bagua_tensor(
name, bagua_module.bagua_module_name
param = param.bagua_ensure_grad().ensure_bagua_tensor(
name,
bagua_module.bagua_module_name,
getter_closure=lambda param: param.grad,
setter_closure=lambda param, t: setattr(param, "grad", t),
)
param._bagua_grad = grad
tensors.append(grad)
tensors.append(param)

self._communication_tensor_names = set(name for name, _ in parameters)
assert len(self._communication_tensor_names) == len(
tensors
@@ -101,9 +104,10 @@ def init_backward_hook(self, bagua_module: BaguaModule):
def hook(parameter_name, parameter):
if parameter_name in self._communication_tensor_names:
assert (
parameter._bagua_grad.data_ptr() == parameter.grad.data_ptr()
parameter._bagua_backend_tensor.data_ptr()
== parameter.grad.data_ptr()
), "bagua grad data_ptr should match parameter grad"
parameter._bagua_grad.bagua_mark_communication_ready()
parameter.bagua_mark_communication_ready()

return hook

39 changes: 19 additions & 20 deletions bagua/torch_api/bucket.py
@@ -52,7 +52,10 @@ def __init__(
# padding tensor must be of name bagua_padding_tensor, so that they are always marked as ready for communication in the backend
self.padding_tensor = torch.zeros(
padding, dtype=self.tensors[0].dtype, device=self.tensors[0].device
).to_bagua_tensor("bagua_padding_tensor_bucket_" + name)
).to_bagua_tensor(
"bagua_padding_tensor_bucket_" + name,
module_name=self.bagua_module_name,
)

self._all_tensors = (
self.tensors + [self.padding_tensor]
@@ -66,7 +69,8 @@ def __init__(
self._flatten_()

self.backend_bucket = B.BaguaBucketPy(
name, [tensor._bagua_backend_tensor for tensor in self._all_tensors]
name,
[tensor.bagua_backend_tensor() for tensor in self._all_tensors],
)

for tensor in self._all_tensors:
@@ -88,7 +92,9 @@ def flattened_tensor(self) -> BaguaTensor:
offset = 0
for tensor in self._all_tensors:
# copy data
flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1)
flatten_tensor[
offset : offset + tensor.numel()
] = tensor.getter_closure().reshape(-1)
offset += tensor.numel()
return flatten_tensor

@@ -101,19 +107,12 @@ def _flatten_(self):

if len(self._all_tensors) == 0:
return
total_size = 0
for tensor in self._all_tensors:
total_size += tensor.numel()

flatten_tensor = torch.zeros(total_size, dtype=self._all_tensors[0].dtype).to(
self._all_tensors[0].device
)
flatten_tensor = self.flattened_tensor()
flatten_storage = flatten_tensor.storage()

offset = 0
for tensor in self._all_tensors:
# copy data
flatten_tensor[offset : offset + tensor.numel()] = tensor.data.reshape(-1)
tensor.bagua_set_storage(flatten_storage, offset)
offset += tensor.numel()

@@ -127,7 +126,7 @@ def check_flatten(self) -> bool:
Returns:
True if the bucket's tensors are contiguous in memory.
"""
return check_contiguous(self._all_tensors)
return check_contiguous([t.getter_closure() for t in self._all_tensors])

def append_python_op(self, python_function: Callable[[str], None]):
"""
@@ -227,15 +226,15 @@ def append_decentralized_synchronous_op(
self._bagua_backend.intranode_communicator,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
peer_weight=peer_weight._bagua_backend_tensor,
peer_weight=peer_weight.bagua_backend_tensor(),
)
else:
self.backend_bucket.append_decentralized_synchronous_op(
self._bagua_backend.global_communicator,
None,
hierarchical=hierarchical,
peer_selection_mode=peer_selection_mode,
peer_weight=peer_weight._bagua_backend_tensor,
peer_weight=peer_weight.bagua_backend_tensor(),
)

def decentralized_synchronous_op_copy_back_peer_weight(
@@ -299,9 +298,9 @@ def append_low_precision_decentralized_synchronous_op(
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
weight=weight._bagua_backend_tensor,
left_peer_weight=left_peer_weight._bagua_backend_tensor,
right_peer_weight=right_peer_weight._bagua_backend_tensor,
weight=weight.bagua_backend_tensor(),
left_peer_weight=left_peer_weight.bagua_backend_tensor(),
right_peer_weight=right_peer_weight.bagua_backend_tensor(),
)
else:
self.backend_bucket.append_low_precision_decentralized_synchronous_op(
Expand All @@ -310,9 +309,9 @@ def append_low_precision_decentralized_synchronous_op(
hierarchical=hierarchical,
peer_selection_mode="ring",
compression=compression,
weight=weight._bagua_backend_tensor,
left_peer_weight=left_peer_weight._bagua_backend_tensor,
right_peer_weight=right_peer_weight._bagua_backend_tensor,
weight=weight.bagua_backend_tensor(),
left_peer_weight=left_peer_weight.bagua_backend_tensor(),
right_peer_weight=right_peer_weight.bagua_backend_tensor(),
)

def append_asynchronous_model_average_op(self, peer_selection_mode: str):
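The bucket now reads each tensor's value through its getter closure when building the flattened tensor and rebinds tensors onto the flattened storage through bagua_set_storage, so _flatten_ works whether the registered tensor is the parameter itself or its gradient. The standalone sketch below uses plain PyTorch only (no Bagua, and none of the names come from this PR) to show the mechanism it relies on: copy every tensor into one contiguous buffer, then set_ each tensor onto that buffer's storage so it becomes a view of the bucket.

import torch

# Standalone illustration in plain PyTorch: flatten two tensors into one contiguous
# buffer and rebind them as views of it, analogous to what BaguaBucket._flatten_
# achieves via tensor.bagua_set_storage(flatten_storage, offset).
a = torch.randn(3)
b = torch.randn(5)
tensors = [a, b]

flat = torch.zeros(sum(t.numel() for t in tensors), dtype=a.dtype)
offset = 0
for t in tensors:
    flat[offset : offset + t.numel()] = t.reshape(-1)  # copy data into the flat buffer
    offset += t.numel()

storage = flat.storage()
offset = 0
for t in tensors:
    with torch.no_grad():
        t.set_(storage, offset, t.shape)  # t now views the flat buffer at this offset
    offset += t.numel()

flat[0] = 42.0
assert a[0].item() == 42.0  # a shares memory with the bucket after rebinding

check_flatten in the diff above performs the corresponding verification: it checks that the tensors returned by the getter closures are still laid out contiguously in one storage.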
82 changes: 67 additions & 15 deletions bagua/torch_api/tensor.py
@@ -14,15 +14,24 @@ class BaguaTensor:
"""

def _bagua_sanity_check(self):
assert self._bagua_backend_tensor.data_ptr() == self.data_ptr()
assert self._bagua_backend_tensor.num_elements() == self.numel()
assert self._bagua_backend_tensor.num_elements_allocated() == self.numel()
assert self._bagua_backend_tensor.data_ptr() == self.getter_closure().data_ptr()
assert (
self._bagua_backend_tensor.num_elements() == self.getter_closure().numel()
)
assert (
self._bagua_backend_tensor.num_elements_allocated()
== self.getter_closure().numel()
)

def is_bagua_tensor(self) -> bool:
return hasattr(self, "_bagua_backend_tensor")

def ensure_bagua_tensor(
self, name: Optional[str] = None, module_name: Optional[str] = None
self,
name: Optional[str] = None,
module_name: Optional[str] = None,
getter_closure=None,
setter_closure=None,
):
"""
Convert a PyTorch tensor or parameter to Bagua tensor inplace and return it.
@@ -42,25 +51,49 @@ def ensure_bagua_tensor(
assert (
self.bagua_tensor_name == name
), "assigning a different name to existing bagua tensor is forbidden"
return

if module_name is not None:
assert (
self.bagua_module_name == module_name
), "assigning a different module name to existing bagua tensor is forbidden"

self.bagua_tensor_name = name if name is not None else ""
self.bagua_module_name = module_name
self.bagua_backend = (
get_backend(self.bagua_module_name)
if self.bagua_module_name is not None
else None
)

# initialize backend tensor
if setter_closure is not None:
self.setter_closure = lambda t: setter_closure(self, t)
assert (
getter_closure is not None
), "must provide `setter_closure` when `getter_closure` is not None"
else:
self.setter_closure = None

if getter_closure is not None:
self.getter_closure = lambda: getter_closure(self)
else:
self.getter_closure = lambda: self

self._bagua_backend_tensor = B.BaguaTensorPy(
name=self.bagua_tensor_name,
torch_tensor=self,
torch_tensor=self.getter_closure(),
)
self._bagua_sanity_check()

self._bagua_ready_event = torch.cuda.Event()
self._bagua_bucket = None
return self

def to_bagua_tensor(
self, name: Optional[str] = None, module_name: Optional[str] = None
self,
name: Optional[str] = None,
module_name: Optional[str] = None,
getter_closure=None,
setter_closure=None,
):
"""
Create a new Bagua tensor from a PyTorch tensor or parameter and return it.
@@ -77,7 +110,9 @@ def to_bagua_tensor(
The new Bagua tensor sharing the same storage with the original tensor.
"""
new_tensor = torch.Tensor(cdata=self._cdata)
return new_tensor.ensure_bagua_tensor(name, module_name)
return new_tensor.ensure_bagua_tensor(
name, module_name, getter_closure, setter_closure
)

def bagua_backend_tensor(self) -> B.BaguaTensorPy:
"""
@@ -92,12 +127,12 @@ def bagua_ensure_grad(self) -> torch.Tensor:
if not exist.
"""
if hasattr(self, "grad") and self.grad is not None:
return self.grad
return self
elif isinstance(self, torch.nn.Parameter):
with torch.no_grad():
t = torch.zeros_like(self.data)
self.grad = t
return self.grad
return self
else:
raise NotImplementedError

@@ -110,7 +145,7 @@ def bagua_mark_communication_ready(self):
self.bagua_backend is not None
), "tensor must be initialized with module name to call mark ready"
self.bagua_backend.mark_communication_ready(
self._bagua_backend_tensor,
self.bagua_backend_tensor(),
self._bagua_ready_event.cuda_event,
)

@@ -122,20 +157,37 @@ def bagua_mark_communication_ready_without_synchronization(self):
self.bagua_backend is not None
), "tensor must be initialized with module name to call mark ready"
self.bagua_backend.mark_communication_ready(
self._bagua_backend_tensor,
self.bagua_backend_tensor(),
0,
)

def bagua_set_storage(self, storage: torch.Storage, storage_offset: int = 0):
def bagua_reset_(self, tensor: torch.Tensor):
assert self.setter_closure is not None
self.setter_closure(tensor)
self._bagua_backend_tensor.reset(tensor)

def bagua_set_storage(
self,
storage: torch.Storage,
storage_offset: int = 0,
):
"""
Sets the underlying storage using an existing `torch.Storage <https://pytorch.org/docs/stable/storage.html?highlight=storage>`_.

Args:
storage: The storage to use.
storage_offset: The offset in the storage.
"""
if self.setter_closure is None:
# set directly
with torch.no_grad():
self.getter_closure().set_(storage, storage_offset, self.shape)
return

with torch.no_grad():
self.set_(storage, storage_offset, self.shape)
t = torch.zeros_like(self.getter_closure())
t.set_(storage, storage_offset, t.shape)
self.bagua_reset_(t)


_base = gorilla._get_base(BaguaTensor)
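The central behavioral change in tensor.py is that bagua_set_storage no longer always calls set_ on the wrapped tensor: when a setter closure is registered, it builds a fresh tensor over the bucket storage, hands it to the setter (for example assigning it to param.grad), and resets the backend tensor through bagua_reset_. The toy class below is a stand-in written for illustration, not the real Bagua API; it only traces that control flow, and the backend reset is reduced to a comment because in Bagua it goes through the Rust reset binding shown in the next two files.

import torch

# Toy stand-in (illustrative names, not Bagua's classes) for the two code paths of
# BaguaTensor.bagua_set_storage shown in the tensor.py diff above.
class MutableTensorSlot:
    def __init__(self, wrapped, getter_closure=None, setter_closure=None):
        self.wrapped = wrapped
        if getter_closure is not None:
            self.getter_closure = lambda: getter_closure(wrapped)
        else:
            self.getter_closure = lambda: wrapped
        if setter_closure is not None:
            self.setter_closure = lambda t: setter_closure(wrapped, t)
        else:
            self.setter_closure = None

    def set_storage(self, storage, offset=0):
        if self.setter_closure is None:
            # No setter: rebind the existing tensor in place, as before this PR.
            with torch.no_grad():
                target = self.getter_closure()
                target.set_(storage, offset, target.shape)
            return
        # Setter path: create a new tensor viewing the bucket storage and hand it to
        # the setter (e.g. replacing param.grad). The real bagua_reset_ additionally
        # resets the backend tensor via its Rust reset binding.
        with torch.no_grad():
            t = torch.zeros_like(self.getter_closure())
            t.set_(storage, offset, t.shape)
        self.setter_closure(t)

# Usage: operate on a parameter's gradient, as the algorithms in this PR do.
p = torch.nn.Parameter(torch.randn(4))
p.grad = torch.zeros(4)
slot = MutableTensorSlot(
    p,
    getter_closure=lambda param: param.grad,
    setter_closure=lambda param, t: setattr(param, "grad", t),
)
bucket = torch.zeros(4)
slot.set_storage(bucket.storage(), 0)
assert p.grad.data_ptr() == bucket.data_ptr()  # grad now lives inside the bucket buffer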
28 changes: 28 additions & 0 deletions rust/bagua-core/bagua-core-internal/src/datatypes/mod.rs
@@ -792,6 +792,34 @@ impl BaguaTensor {
})
}

pub fn reset_from_torch(
&self,
torch_cdata_ptr: u64,
dtype: BaguaTensorDtype,
) -> pyo3::PyResult<()> {
if dtype != self.inner.read().raw.dtype() {
return Err(pyo3::exceptions::PyRuntimeError::new_err(
"could not reset tensor from different tensor type",
));
}

let mut torch_tensor = TorchTensorRaw {
torch_tensor_cdata: torch_cdata_ptr,
python_fallback: false,
dtype,
};

let consistency = torch_tensor.check_consistency_with_python()?;
if !consistency {
tracing::warn!(
r#"PyTorch tensor memory layout inconsistent with latest PyTorch. Bagua will fallback to Python interface. This will degrade system performance. We suggest upgrading to latest PyTorch."#
)
}
torch_tensor.python_fallback = !consistency;
self.inner.write().raw = Box::new(torch_tensor);
return Ok(());
}

pub fn mark_comm_ready(&self, cuda_event_ptr: u64) {
if cuda_event_ptr == 0 {
tracing::info!("mark comm ready with an event 0, ignoring event");
29 changes: 29 additions & 0 deletions rust/bagua-core/bagua-core-py/src/lib.rs
@@ -250,6 +250,35 @@ impl BaguaTensorPy {
})
}

pub fn reset(&self, torch_tensor: &PyAny) -> PyResult<()> {
// TODO: sanity check
let dtype = torch_tensor
.getattr("dtype")
.expect("must pass valid torch tensor")
.repr()?
.to_string();
let bagua_dtype = match dtype.as_str() {
"torch.float32" => BaguaTensorDtype::F32,
"torch.float16" => BaguaTensorDtype::F16,
"torch.int64" => BaguaTensorDtype::I64,
"torch.uint8" => BaguaTensorDtype::U8,
_ => {
return Err(PyRuntimeError::new_err(format!(
"unsupported tensor dtype {}",
dtype
)))
}
};

self.inner.reset_from_torch(
torch_tensor
.getattr("_cdata")
.expect("must pass valid torch tensor")
.extract()?,
bagua_dtype,
)
}

pub fn compress(&self, method: &str, n_chunks: usize, target_chunk: i32) -> Self {
Self {
inner: self.inner.compress(method, n_chunks, target_chunk),