feat(python): support moe #208

Merged: 31 commits, merged on Sep 28, 2021

Commits
db9963f  feat(python): support moe (liuhatry, Sep 17, 2021)
d3651e1  update (liuhatry, Sep 17, 2021)
47dc14b  fix code stype (liuhatry, Sep 18, 2021)
258ff6f  Update examples/moe/main.py (liuhatry, Sep 18, 2021)
ce8ffe6  update (liuhatry, Sep 18, 2021)
755c674  Merge branch 'megatron' of https://github.com/BaguaSys/bagua into meg… (liuhatry, Sep 18, 2021)
076e8a6  update (liuhatry, Sep 18, 2021)
d62347a  update (liuhatry, Sep 18, 2021)
ebc9b30  update (liuhatry, Sep 18, 2021)
f65d8af  Update bagua/torch_api/distributed.py (liuhatry, Sep 18, 2021)
ab7770e  update (liuhatry, Sep 18, 2021)
acc26c9  Update examples/mnist/main.py (liuhatry, Sep 18, 2021)
6fda627  Update examples/mnist/main.py (liuhatry, Sep 18, 2021)
7867e5d  add moe test (liuhatry, Sep 18, 2021)
6a84c5e  Merge branch 'megatron' of https://github.com/BaguaSys/bagua into meg… (liuhatry, Sep 18, 2021)
2a6cb52  udpate (liuhatry, Sep 18, 2021)
8379d5a  update (liuhatry, Sep 23, 2021)
8160021  update (liuhatry, Sep 23, 2021)
28bc3e2  deepspeed moe (liuhatry, Sep 23, 2021)
edf7e2d  update (liuhatry, Sep 23, 2021)
0971cca  Update bagua/torch_api/moe/experts.py (liuhatry, Sep 23, 2021)
0e67467  update (liuhatry, Sep 23, 2021)
2c2cbd0  Merge branch 'megatron' of https://github.com/BaguaSys/bagua into meg… (liuhatry, Sep 23, 2021)
7399925  update (liuhatry, Sep 23, 2021)
0a8b612  update (liuhatry, Sep 23, 2021)
a1d7162  update (liuhatry, Sep 23, 2021)
ab404ff  update (liuhatry, Sep 23, 2021)
04658ce  update (liuhatry, Sep 27, 2021)
0fc5ad1  update (liuhatry, Sep 27, 2021)
e50fdfc  update (liuhatry, Sep 27, 2021)
013a3c3  Update requirements.txt (NOBLES5E, Sep 28, 2021)
62 changes: 49 additions & 13 deletions .buildkite/scripts/benchmark_master.sh
@@ -5,9 +5,24 @@ echo "$BUILDKITE_PARALLEL_JOB_COUNT"

set -euox pipefail

# 0. install bagua
cp -a /upstream /workdir
export HOME=/workdir && cd $HOME && bash .buildkite/scripts/install_bagua.sh || exit 1

CHECK_RESULT=()

# 1. test communication_primitives api
echo "begin to test [communication_primitives]"
COMMUNICATION_SCRIPT="/workdir/examples/communication_primitives/main.py"
python -m bagua.distributed.launch \
--nnodes=2 \
--nproc_per_node 4 \
--node_rank=0 \
--master_addr="10.158.66.134" \
--master_port=1234 \
${COMMUNICATION_SCRIPT}


# 2. benchmark test with all communication algorithms
function check_benchmark_log {
logfile=$1
algorithm=$2
@@ -35,18 +50,7 @@ function check_benchmark_log {
fi
}

export HOME=/workdir && cd $HOME && bash .buildkite/scripts/install_bagua.sh || exit 1

echo "begin to test [communication_primitives]"
COMMUNICATION_SCRIPT="/workdir/examples/communication_primitives/main.py"
python -m bagua.distributed.launch \
--nnodes=2 \
--nproc_per_node 4 \
--node_rank=0 \
--master_addr="10.158.66.134" \
--master_port=1234 \
${COMMUNICATION_SCRIPT}

CHECK_RESULT=()
SYNTHETIC_SCRIPT="/workdir/examples/benchmark/synthetic_benchmark.py"
algorithms=(gradient_allreduce bytegrad decentralized low_precision_decentralized)
speeds=(185.0 180.0 150.0 115.0 170 0)
@@ -81,3 +85,35 @@ if [ ${#CHECK_RESULT[*]} -gt 0 ]; then
echo -e ${CHECK_RESULT[*]}
exit 1
fi

# 3. test moe
function check_moe_log {
logfile=$1
loss=$2

final_batch_loss=$(cat ${logfile} | grep "Loss" | tail -n 1 | awk '{print $NF}')

if [ $final_batch_loss == $loss ]; then
echo "Check moe success, final_batch_loss is equal."
else
result="Check moe fail, final_batch_loss["$final_batch_loss"] is not equal with "$loss"."
echo $result
exit 1
fi
}

MOE_SCRIPT="/workdir/examples/moe/mnist_main.py"
logfile=$(mktemp /tmp/bagua_moe_gradient_allreduce.XXXXXX.log)
CUDA_VISIBLE_DEVICES=0,1 python -m bagua.distributed.launch \
--nnodes=2 \
--nproc_per_node 2 \
--node_rank=0 \
--master_addr="10.158.66.134" \
--master_port=1234 \
${MOE_SCRIPT} \
--algorithm gradient_allreduce \
--epochs 5 \
--num-local-experts 2 \
--set-deterministic \
2>&1 | tee ${logfile}
check_moe_log ${logfile} 0.000293
22 changes: 21 additions & 1 deletion .buildkite/scripts/benchmark_worker.sh
@@ -5,10 +5,12 @@ echo "$BUILDKITE_PARALLEL_JOB_COUNT"

set -euox pipefail

# 0. install bagua
cp -a /upstream /workdir

export HOME=/workdir && cd $HOME && bash .buildkite/scripts/install_bagua.sh || exit 1


# 1. test communication_primitives api
echo "begin to test [communication_primitives]"
COMMUNICATION_SCRIPT="/workdir/examples/communication_primitives/main.py"
python -m bagua.distributed.launch \
@@ -19,6 +21,8 @@ python -m bagua.distributed.launch \
--master_port=1234 \
${COMMUNICATION_SCRIPT}


# 2. benchmark test with all communication algorithms
SYNTHETIC_SCRIPT="/workdir/examples/benchmark/synthetic_benchmark.py"
algorithms=(gradient_allreduce bytegrad decentralized low_precision_decentralized)
length=${#algorithms[@]}
@@ -41,3 +45,19 @@
--deterministic \
2>&1 | tee ${logfile}
done

# 3. test moe
MOE_SCRIPT="/workdir/examples/moe/mnist_main.py"
logfile=$(mktemp /tmp/bagua_moe_gradient_allreduce.XXXXXX.log)
CUDA_VISIBLE_DEVICES=0,1 python -m bagua.distributed.launch \
--nnodes=2 \
--nproc_per_node 2 \
--node_rank=1 \
--master_addr="10.158.66.134" \
--master_port=1234 \
${MOE_SCRIPT} \
--algorithm gradient_allreduce \
--epochs 5 \
--num-local-experts 2 \
--set-deterministic \
2>&1 | tee ${logfile}
1 change: 1 addition & 0 deletions bagua/torch_api/__init__.py
@@ -57,3 +57,4 @@
from . import contrib # noqa: F401
from . import communication # noqa: F401
from . import algorithms # noqa: F401
from .model_parallel import moe # noqa: E402,F401
1 change: 1 addition & 0 deletions bagua/torch_api/distributed.py
@@ -58,6 +58,7 @@ def bagua_build_params(self) -> List[Tuple[str, torch.nn.Parameter]]:
for param_name, param in module.named_parameters(recurse=False)
if param.requires_grad
and f"{module_name}.{param_name}" not in self.parameters_to_ignore
and (not getattr(param, "expert", False))
]
]

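The new "expert" check above keeps expert parameters out of the buckets that the data-parallel algorithm fuses and synchronizes; those weights are handled by the MoE layer itself. Below is a minimal sketch of how such a flag could be attached to a parameter. Only the attribute name "expert" comes from the diff; the toy module and the place where the flag is set are assumptions for illustration.

import torch

# Hypothetical expert sub-module; in the real code path the MoE implementation
# is responsible for tagging its own parameters.
expert_mlp = torch.nn.Linear(128, 128)

for _, param in expert_mlp.named_parameters():
    # Parameters carrying this attribute are skipped by bagua_build_params,
    # so they are not fused or all-reduced with the dense (data-parallel) weights.
    param.expert = True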
Empty file.
1 change: 1 addition & 0 deletions bagua/torch_api/model_parallel/moe/__init__.py
@@ -0,0 +1 @@
from .layer import MoE # noqa: F401
41 changes: 41 additions & 0 deletions bagua/torch_api/model_parallel/moe/experts.py
@@ -0,0 +1,41 @@
# Copyright (c) 2021 Kuaishou AI Platform & DS3 Lab
#
# All rights reserved.
#
# The file has been adapted from DeepSpeed:
# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
# Git commit hash: bff6126f0ddbd1a03da66867571ac87b11c21ac1
# We retain the following license from the original files:

# Copyright 2020 The Microsoft DeepSpeed Team

import torch
import copy


class Experts(torch.nn.Module):
def __init__(self, expert, num_local_experts=1):
super(Experts, self).__init__()

self.deepspeed_experts = torch.nn.ModuleList(
[copy.deepcopy(expert) for i in range(num_local_experts)]
)
self.num_local_experts = num_local_experts

# TODO: revisit allreduce for moe.gate...
for expert in self.deepspeed_experts:
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
for name, param in expert.named_parameters():
param.allreduce = False

def forward(self, inputs):
chunks = inputs.chunk(self.num_local_experts, dim=1)
expert_outputs = []
for chunk, expert in zip(chunks, self.deepspeed_experts):
out = expert(chunk)
if type(out) is tuple:
out = out[0] # Ignore the bias term for now
expert_outputs += [out]

expert_output = torch.cat(expert_outputs, dim=1)
return expert_output
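
For reference, a short sketch of how this wrapper behaves: the routed input is split along dim=1 into one chunk per local expert, each chunk runs through its own deep copy of the expert, and the per-expert outputs are concatenated back along the same dimension. The tensor layout below (world size 1, two local experts, capacity 4, model dim 16) is an illustrative assumption rather than the exact convention MOELayer uses.

import torch
from bagua.torch_api.model_parallel.moe.experts import Experts  # module added in this PR

expert = torch.nn.Linear(16, 16)               # toy expert acting on the model dimension
experts = Experts(expert, num_local_experts=2)

inputs = torch.randn(1, 2, 4, 16)              # assumed (groups, local_experts, capacity, model_dim)
output = experts(inputs)
print(output.shape)                            # torch.Size([1, 2, 4, 16]); chunk i went through expert copy i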
95 changes: 95 additions & 0 deletions bagua/torch_api/model_parallel/moe/layer.py
@@ -0,0 +1,95 @@
# Copyright (c) 2021 Kuaishou AI Platform & DS3 Lab
#
# All rights reserved.
#
# The file has been adapted from DeepSpeed:
# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/layer.py
# Git commit hash: bff6126f0ddbd1a03da66867571ac87b11c21ac1
# We retain the following license from the original files:

# Copyright 2020 The Microsoft DeepSpeed Team

import bagua.torch_api as bagua
import logging
import torch
import torch.distributed as dist

from .sharded_moe import MOELayer, TopKGate
from .experts import Experts
import typing


class MoE(torch.nn.Module):
def __init__(self,
hidden_size,
expert,
num_local_experts=1,
k=1,
output_dropout_prob=0.0,
capacity_factor=1.,
eval_capacity_factor=1.,
min_capacity=4,
noisy_gate_policy: typing.Optional[str] = None):
"""Initialize an MoE layer.

Arguments:
hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.

expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).

num_local_experts (int, optional): default=1, number of local experts per gpu.

k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.

output_dropout_prob (float, optional): default=0.0, output dropout probability.

capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.

eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.

min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.

noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
"""

super(MoE, self).__init__()

assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
'Unsupported noisy_gate_policy: ' + noisy_gate_policy

self.num_experts = num_local_experts * bagua.get_world_size()
logging.info(f'num_experts: {self.num_experts} | num_local_experts: {num_local_experts} | world_size: {bagua.get_world_size()}')

experts = Experts(expert, num_local_experts)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size,
self.num_experts,
k,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy),
experts,
num_local_experts,
group=dist.group.WORLD)

self.dropout = torch.nn.Dropout(output_dropout_prob)

def forward(self, hidden_states, used_token=None):
""" MoE forward

Arguments:
hidden_states (Tensor): input to the layer
used_token (Tensor, optional): default: None, mask only used tokens

Returns:
A tuple including output, gate loss, and expert count.

* output (Tensor): output of the model

* l_aux (Tensor): gate loss value

* exp_counts (int): expert count
"""
output = self.deepspeed_moe(hidden_states, used_token)
output = self.dropout(output)
return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
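
To illustrate the API defined above, here is a hedged usage sketch. It assumes the Bagua process group has already been initialized (e.g. via bagua.init_process_group() in a training script), because num_experts is derived from bagua.get_world_size(); the expert module, sizes, and gate settings are made up for illustration.

import torch
from bagua.torch_api.model_parallel.moe import MoE  # re-exported by the new moe/__init__.py

hidden_size = 128

# A simple feed-forward expert; per the docstring, any torch.nn.Module whose
# input and output dimensions match hidden_size should work.
expert = torch.nn.Sequential(
    torch.nn.Linear(hidden_size, 4 * hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(4 * hidden_size, hidden_size),
)

moe_layer = MoE(
    hidden_size=hidden_size,
    expert=expert,
    num_local_experts=2,           # experts hosted on this rank
    k=1,                           # top-1 gating
    noisy_gate_policy="RSample",
)

x = torch.randn(8, 16, hidden_size)        # assumed (batch, sequence, hidden) layout
output, l_aux, exp_counts = moe_layer(x)   # add l_aux to the task loss so the gate is trained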