Make MegaBlocks go vroom on Hopper. #24

Merged
merged 9 commits on Sep 25, 2023
Conversation

tgale96 (Contributor) commented on Sep 23, 2023

Add a grouped GEMM-based dMoE to work around Triton limitations on SM90. Guard turbo use so that it does not need to be installed when quantization is not enabled. Add layer-wise dMoE benchmarks.

After this PR, we recommend enabling grouped_mlp for SM90. grouped_mlp should be used only with expert model parallelism to keep per-device expert counts low, which is important for efficiency with the current cuBLAS-based grouped GEMM kernels.
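As a sketch, enabling this looks roughly like the following. This assumes the Arguments/dMoE interfaces in this repo and that grouped_mlp (the flag shown in the benchmark output below) is the switch added by this PR; verify the exact field names against the code.

import torch
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE

# Build a dMoE layer with the grouped-GEMM expert MLP enabled, using the
# benchmark shapes below.
args = Arguments(
    hidden_size=2048,
    ffn_hidden_size=2048,
    moe_num_experts=32,
    moe_top_k=4,
    grouped_mlp=True,  # assumed flag from this PR; recommended on SM90 (Hopper)
)
layer = dMoE(args).cuda()

# A batch of token activations; the dtype is matched to the layer's
# parameters since the Arguments defaults may select fp16.
x = torch.randn(2, 2048, args.hidden_size, device='cuda')
x = x.to(next(layer.parameters()).dtype)
out = layer(x)  # the forward return convention (e.g. an extra bias term) may vary by version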

tgale96 (Contributor, Author) commented on Sep 23, 2023

dMoE forward benchmarks on 8x H100 with 8-way expert model parallelism (EMP):

============================================================
dMoE (Fwd) Benchmark Summary
Common parameters:
  sequence_length = 2048
  ffn_hidden_size = hidden_size
  num_experts = 32
  top_k = 4
Forward time, mean ± std:
  batch_size  hidden_size  grouped_mlp=False    grouped_mlp=True
  2           2048          4.263 ±  3.202 ms    3.605 ±  3.911 ms
  4           2048          7.239 ±  5.606 ms    6.690 ±  4.307 ms
  2           2560          5.165 ±  4.151 ms    4.092 ±  3.154 ms
  4           2560          8.410 ±  5.480 ms    7.575 ±  4.554 ms
  2           4096          7.288 ±  3.739 ms    5.638 ±  3.959 ms
  4           4096         13.633 ±  4.487 ms   10.527 ±  3.780 ms
  2           5120          9.172 ±  4.656 ms    7.209 ±  4.374 ms
  4           5120         17.286 ±  5.826 ms   12.779 ±  5.501 ms
  2           7168         14.249 ±  4.010 ms   10.088 ±  3.611 ms
  4           7168         28.500 ±  6.034 ms   19.245 ±  5.647 ms
============================================================
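For anyone reproducing these numbers, a minimal sketch of a layer-wise forward timing loop using standard CUDA events. The helper name and the warmup/iteration counts are illustrative assumptions, not the benchmark script added in this PR.

import torch

def benchmark_fwd(layer, x, warmup=10, iters=100):
    # Warm up to exclude one-time costs (autotuning, allocator growth).
    for _ in range(warmup):
        layer(x)
    torch.cuda.synchronize()

    # Time each forward pass on-device with CUDA events (milliseconds).
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(iters):
        start.record()
        layer(x)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    t = torch.tensor(times)
    print(f'mean time = {t.mean().item():.3f}ms, std time = {t.std().item():.3f}ms')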
