
feat(python): support moe #208

Merged
31 commits merged on Sep 28, 2021
Changes from 9 commits
Commits (31)
db9963f
feat(python): support moe
liuhatry Sep 17, 2021
d3651e1
update
liuhatry Sep 17, 2021
47dc14b
fix code style
liuhatry Sep 18, 2021
258ff6f
Update examples/moe/main.py
liuhatry Sep 18, 2021
ce8ffe6
update
liuhatry Sep 18, 2021
755c674
Merge branch 'megatron' of https://github.com/BaguaSys/bagua into meg…
liuhatry Sep 18, 2021
076e8a6
update
liuhatry Sep 18, 2021
d62347a
update
liuhatry Sep 18, 2021
ebc9b30
update
liuhatry Sep 18, 2021
f65d8af
Update bagua/torch_api/distributed.py
liuhatry Sep 18, 2021
ab7770e
update
liuhatry Sep 18, 2021
acc26c9
Update examples/mnist/main.py
liuhatry Sep 18, 2021
6fda627
Update examples/mnist/main.py
liuhatry Sep 18, 2021
7867e5d
add moe test
liuhatry Sep 18, 2021
6a84c5e
Merge branch 'megatron' of https://github.com/BaguaSys/bagua into meg…
liuhatry Sep 18, 2021
2a6cb52
update
liuhatry Sep 18, 2021
8379d5a
update
liuhatry Sep 23, 2021
8160021
update
liuhatry Sep 23, 2021
28bc3e2
deepspeed moe
liuhatry Sep 23, 2021
edf7e2d
update
liuhatry Sep 23, 2021
0971cca
Update bagua/torch_api/moe/experts.py
liuhatry Sep 23, 2021
0e67467
update
liuhatry Sep 23, 2021
2c2cbd0
Merge branch 'megatron' of https://github.com/BaguaSys/bagua into meg…
liuhatry Sep 23, 2021
7399925
update
liuhatry Sep 23, 2021
0a8b612
update
liuhatry Sep 23, 2021
a1d7162
update
liuhatry Sep 23, 2021
ab404ff
update
liuhatry Sep 23, 2021
04658ce
update
liuhatry Sep 27, 2021
0fc5ad1
update
liuhatry Sep 27, 2021
e50fdfc
update
liuhatry Sep 27, 2021
013a3c3
Update requirements.txt
NOBLES5E Sep 28, 2021
1 change: 1 addition & 0 deletions bagua/torch_api/__init__.py
@@ -57,3 +57,4 @@
from . import contrib # noqa: F401
from . import communication # noqa: F401
from . import algorithms # noqa: F401
from . import moe # noqa: E402,F401
5 changes: 3 additions & 2 deletions bagua/torch_api/distributed.py
@@ -56,8 +56,9 @@ def bagua_build_params(self) -> List[Tuple[str, torch.nn.Parameter]]:
# single-process multi device case, where it accesses replicated
# parameters through _former_parameters.
for param_name, param in module.named_parameters(recurse=False)
if param.requires_grad
and f"{module_name}.{param_name}" not in self.parameters_to_ignore
if param.requires_grad and
f"{module_name}.{param_name}" not in self.parameters_to_ignore and
(not getattr(param, "expert", False))
]
]

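Note: the added `(not getattr(param, "expert", False))` clause keeps expert-local parameters out of the buckets that `bagua_build_params` hands to the fused communication algorithms, so expert weights are left out of data-parallel synchronization. A minimal sketch of the flag and the resulting predicate (names other than the `expert` attribute are illustrative, not from the PR):

```python
import torch

# MOELayer tags every expert parameter so Bagua's bucket builder can skip it.
expert = torch.nn.Linear(4, 4)
for p in expert.parameters():
    p.expert = True  # the same attribute set in MOELayer.__init__

# The condition added to bagua_build_params reduces to this predicate:
def is_bucketed(param: torch.nn.Parameter) -> bool:
    return param.requires_grad and not getattr(param, "expert", False)

print([is_bucketed(p) for p in expert.parameters()])  # [False, False]
```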
6 changes: 6 additions & 0 deletions bagua/torch_api/moe/__init__.py
@@ -0,0 +1,6 @@
from typing import List

from .moe_layer import MOELayer # noqa: F401
from .top2gate import Top2Gate # noqa: F401

__all__: List[str] = []
96 changes: 96 additions & 0 deletions bagua/torch_api/moe/moe_layer.py
@@ -0,0 +1,96 @@
# The file has been adapted from:
# https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
# Git commit hash: 180ab8c8464d1c2a22556df9923696bbc2c92076
# We retain the following license from the original files:

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

import torch
from torch import Tensor
import torch.distributed as dist
from torch.nn import Module, ModuleList

if TYPE_CHECKING:
Base = Module[Tensor]
else:
Base = Module

# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
ctx.group = group
input = input.contiguous()
output = torch.empty_like(input)
dist.all_to_all_single(output, input, group=group)
return output

@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _AllToAll.apply(ctx.group, *grad_output))


class MOELayer(Base):
"""MOELayer module which implements MixtureOfExperts as described in Gshard_.
::

gate = Top2Gate(model_dim, num_experts)
moe = MOELayer(gate, expert)
output = moe(input)
l_aux = moe.l_aux

.. _Gshard: https://arxiv.org/pdf/2006.16668.pdf

Args:
gate: gate network
experts: a single expert network or a ModuleList of expert networks
group: group to use for all-to-all communication
"""

def __init__(self, gate: Module, experts: Union[Module, ModuleList], group: Optional[Any] = None) -> None:
super().__init__()
self.gate = gate
if type(experts) == ModuleList:
self.experts = cast(ModuleList, experts)
else:
self.experts = ModuleList([experts])
self.group = group if group is not None else dist.group.WORLD
for expert in self.experts:
for p in expert.parameters():
p.expert = True # type: ignore
self.world_size = dist.get_world_size(self.group)
self.num_local_experts = len(self.experts)

def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
# assert len(input) == 1, "only single input Tensor supported"
# assert len(input[0].shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
# assert input[0].shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"

# Implement Algorithm 2 from GShard paper.
d_model = input[0].shape[-1]
# Reshape into S tokens by dropping sequence dimension.
reshaped_input = input[0].reshape(-1, d_model)
self.l_aux, combine_weights, dispatch_mask = self.gate(reshaped_input)
dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), reshaped_input)
dispatched_input = _AllToAll.apply(self.group, dispatched_input)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.world_size, self.num_local_experts, -1, d_model)
chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
expert_outputs = []
for chunk, expert in zip(chunks, self.experts):
expert_outputs += [expert(chunk)]
expert_output = torch.cat(expert_outputs, dim=1)
expert_output = _AllToAll.apply(self.group, expert_output)
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.world_size * self.num_local_experts, -1, d_model)
combined_output = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
return combined_output.reshape(input[0].shape)
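Note: to make the `(g)roup/(s)equence/(e)xpert/(c)apacity/(m)odel` shape bookkeeping above concrete, here is a hedged single-rank sketch of the layer end to end. It assumes one CUDA device and an NCCL process group, since the internal `all_to_all_single` generally requires NCCL or MPI; the sizes are arbitrary:

```python
import os
import torch
import torch.distributed as dist
from bagua.torch_api.moe import MOELayer, Top2Gate

# Single-rank NCCL group: the internal all-to-all degenerates to a local copy,
# which is enough to observe the shapes used in MOELayer.forward.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model_dim, num_experts = 16, 2
gate = Top2Gate(model_dim, num_experts)
experts = torch.nn.ModuleList(
    [torch.nn.Linear(model_dim, model_dim) for _ in range(num_experts)]
)
moe = MOELayer(gate, experts).cuda()

x = torch.randn(4, 8, model_dim, device="cuda")  # 3-D input: (sequence, tokens, model)
out = moe(x)                                     # output keeps the input shape
loss = out.sum() + moe.l_aux                     # add the gate's load-balancing loss
loss.backward()
```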
127 changes: 127 additions & 0 deletions bagua/torch_api/moe/top2gate.py
@@ -0,0 +1,127 @@
# The file has been adapted from:
# https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
# Git commit hash: 180ab8c8464d1c2a22556df9923696bbc2c92076
# We retain the following license from the original files:

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, Tuple

import torch
from torch import Tensor
import torch.nn.functional as F

gumbel_map: Dict[torch.device, Callable] = {}


def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
gumbel = gumbel_map.get(device)
if gumbel is None:
one = torch.tensor(1.0, device=device)
zero = torch.tensor(0.0, device=device)
gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
gumbel_map[device] = gumbel
return gumbel(shape)


def one_hot(tensor: torch.Tensor, num_classes: int) -> Tensor:
"""Workaround for https://github.com/pytorch/pytorch/issues/55579"""
assert num_classes > 0, "num_classes must be a positive integer"
ret = torch.zeros(tensor.shape + (num_classes,), device=tensor.device, dtype=tensor.dtype)
ret.scatter_(-1, tensor.unsqueeze(-1), 1)
return ret


def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
# NOTE(msb) softmax requires FP32: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/
gates = F.softmax(logits, dim=1, dtype=torch.float)

# gates has shape of SE
num_tokens = gates.shape[0]
num_experts = gates.shape[1]
# capacity = 2S/E
capacity = 2 * num_tokens // num_experts
# assert num_tokens % num_experts == 0

# Create a mask for 1st's expert per token
indices1_s = torch.argmax(gates, dim=1)
mask1 = one_hot(indices1_s, num_classes=num_experts)

# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# Replace top-expert with min value
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
indices2_s = torch.argmax(logits_except1, dim=1)
mask2 = one_hot(indices2_s, num_classes=num_experts)

# Compute locations in capacity buffer
locations1 = torch.cumsum(mask1, dim=0) - 1
locations2 = torch.cumsum(mask2, dim=0) - 1
# Update 2nd's location by accounting for locations of 1st
locations2 += torch.sum(mask1, dim=0, keepdim=True)

# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
l_aux = torch.mean(me * ce)

# Remove locations outside capacity from mask
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)

# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
locations2_s = torch.sum(locations2 * mask2, dim=1)

# Normalize gate probabilities
gates1_s = (gates * mask1).sum(dim=1) # einsum("se,se->s")
gates2_s = (gates * mask2).sum(dim=1) # einsum("se,se->s")
denom_s = gates1_s + gates2_s
# Avoid divide-by-zero
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
gates1_s /= denom_s
gates2_s /= denom_s

# Calculate combine_weights and dispatch_mask
gates1 = gates1_s.unsqueeze(-1) * mask1 # einsum("s,se->se")
gates2 = gates2_s.unsqueeze(-1) * mask2 # einsum("s,se->se")
locations1_sc = one_hot(locations1_s, num_classes=capacity)
locations2_sc = one_hot(locations2_s, num_classes=capacity)
combine1_sec = gates1.unsqueeze(2) * locations1_sc.unsqueeze(1) # einsum("se,sc->sec")
combine2_sec = gates2.unsqueeze(2) * locations2_sc.unsqueeze(1) # einsum("se,sc->sec")
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()

return l_aux.to(logits.dtype), combine_weights.to(logits.dtype), dispatch_mask


class Top2Gate(torch.nn.Module):
"""Gate module which implements Top2Gating as described in Gshard_.
::

gate = Top2Gate(model_dim, num_experts)
l_aux, combine_weights, dispatch_mask = gate(input)

.. _Gshard: https://arxiv.org/pdf/2006.16668.pdf

Args:
model_dim (int):
size of model embedding dimension
num_experts (int):
number of experts in model
"""

wg: torch.nn.Linear

def __init__(self, model_dim: int, num_experts: int,) -> None:
super().__init__()
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)

def forward(self, input: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
logits = self.wg(input)
return top2gating(logits)
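Note: `top2gating` is pure tensor code, so its outputs can be inspected on CPU without any process group. A short sketch of the returned shapes, using arbitrary example sizes:

```python
import torch
from bagua.torch_api.moe import Top2Gate

num_tokens, model_dim, num_experts = 8, 16, 4
gate = Top2Gate(model_dim, num_experts)

tokens = torch.randn(num_tokens, model_dim)   # flattened (S, M) token batch
l_aux, combine_weights, dispatch_mask = gate(tokens)

capacity = 2 * num_tokens // num_experts      # 2S/E = 4 slots per expert
print(l_aux.shape)            # torch.Size([]): scalar load-balancing loss
print(combine_weights.shape)  # (S, E, C) = (8, 4, 4) routing weights
print(dispatch_mask.shape)    # (S, E, C) boolean dispatch mask
```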
5 changes: 5 additions & 0 deletions examples/moe/README.md
@@ -0,0 +1,5 @@
Use the following command to start training locally with 2 GPUs:

```bash
python3 -m bagua.distributed.launch --nproc_per_node=2 main.py --algorithm gradient_allreduce
```
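The `main.py` referenced by this command is not part of the hunk shown here. A hypothetical sketch of the MoE-specific pieces such a script might contain is given below; class and function names are illustrative, and the gate is assumed to score over the total expert count across ranks:

```python
# Hypothetical sketch only; the real examples/moe/main.py also performs the
# usual Bagua setup (process group initialization, with_bagua, data loading).
import torch
import torch.nn as nn
from bagua.torch_api.moe import MOELayer, Top2Gate


class MoEBlock(nn.Module):
    def __init__(self, model_dim: int, num_local_experts: int, world_size: int = 1) -> None:
        super().__init__()
        experts = nn.ModuleList(
            [nn.Linear(model_dim, model_dim) for _ in range(num_local_experts)]
        )
        # The gate scores over the total number of experts across all ranks.
        gate = Top2Gate(model_dim, world_size * num_local_experts)
        self.moe = MOELayer(gate, experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.moe(x)


def moe_loss(output, target, block: MoEBlock, criterion) -> torch.Tensor:
    # The gate's auxiliary loss must be added so the experts stay load-balanced.
    return criterion(output, target) + block.moe.l_aux
```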