add deepseek2 support (#472)
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com>
Co-authored-by: baishihao <baishihao@sensetime.com>
5 people committed Jul 25, 2024
1 parent 59f30ec commit 27362fa
Showing 19 changed files with 2,370 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,4 +3,5 @@ __pycache__/
build
dist
*.egg-info
.idea
.idea
.vscode
4 changes: 4 additions & 0 deletions README.md
@@ -51,6 +51,8 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving framework
- [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)
- [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
- [CohereForAI](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [DeepSeek-V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)
- [DeepSeek-V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)

> When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'.
@@ -64,6 +66,8 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving framework
> Phi-3 only supports Mini and Small.
> When you start DeepSeek-V2-Lite or DeepSeek-V2, you need to set the parameter '--data_type bfloat16'.
## Get started

### Requirements
8 changes: 8 additions & 0 deletions lightllm/common/deepseek2_mem_manager.py
@@ -0,0 +1,8 @@
import torch

from .mem_manager import MemoryManager


class Deepseek2MemoryManager(MemoryManager):
def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
self.kv_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]
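
For illustration only (not part of this commit): the override above gives every transformer layer one contiguous (size, head_num, head_dim) buffer. A minimal sketch of the memory footprint that layout implies, using made-up sizes; the real head_num, head_dim, and layer_num for DeepSeek-V2 come from the model config, not from this file:

import torch

# Hypothetical sizes, for illustration only.
size, head_num, head_dim, layer_num = 4096, 1, 576, 60
dtype = torch.bfloat16

bytes_per_elem = torch.tensor([], dtype=dtype).element_size()  # 2 for bfloat16
total_bytes = size * head_num * head_dim * bytes_per_elem * layer_num
print(f"total KV buffer footprint: {total_bytes / 1024 ** 2:.1f} MiB")
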
Empty file.
Empty file.
259 changes: 259 additions & 0 deletions lightllm/models/deepseek2/layer_infer/_custom_ops.py
@@ -0,0 +1,259 @@
from typing import List, Optional, Tuple, Type
import torch
import triton
import triton.language as tl


def ceil_div(a, b):
return (a + b - 1) // b


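# Stage 1: each program takes one chunk of the flattened topk_ids and counts how many of its
# entries are routed to each expert, writing the counts into row pid + 1 of tokens_cnts
# (row 0 is left as zeros).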
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)

start_idx = pid * tokens_per_thread

off_c = (pid + 1) * num_experts

for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


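# Stage 2: one program per expert turns that expert's per-chunk counts into a running prefix sum
# over chunks, so row i of tokens_cnts ends up holding how many of the expert's tokens appear in
# chunks 0..i-1 (the last row holds the expert's total).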
@triton.jit
def moe_align_block_size_stage2(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)

last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


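# Stage 3: a single program converts the per-expert totals (last row of tokens_cnts) into
# block_size-padded cumulative offsets in cumsum and records the total padded token count.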
@triton.jit
def moe_align_block_size_stage3(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)


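# Stage 4: each program (its pid doubles as an expert index and a chunk index) first labels the
# blocks owned by its expert in expert_ids, then scatters its chunk's token indices into
# sorted_token_ids using the per-chunk offsets from stage 2 plus the padded expert offsets from
# stage 3.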
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)

for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)

start_idx = pid * tokens_per_thread
off_t = pid * num_experts

for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)


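# Driver for the four stages above: groups the token indices in topk_ids by expert and pads each
# expert's group to a multiple of block_size. Results are written in place into sorted_token_ids
# (padding slots keep their pre-filled sentinel, see test() below), expert_ids (the expert that
# owns each block), and num_tokens_post_pad (the total padded length).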
@torch.no_grad()
def moe_align_block_size(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts,)
tokens_cnts = torch.zeros((num_experts + 1, num_experts), dtype=torch.int32, device="cuda")
cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device="cuda")
tokens_per_thread = ceil_div(numel, num_experts)

moe_align_block_size_stage1[grid](
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
BLOCK_SIZE=num_experts,
)
moe_align_block_size_stage2[grid](
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
BLOCK_SIZE=num_experts,
)
moe_align_block_size_stage3[(1,)](
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
BLOCK_SIZE=num_experts,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
BLOCK_SIZE=num_experts,
)


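# Pure-PyTorch reference with the same in-place output contract, used by test() below to
# validate the Triton kernels.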
@torch.no_grad()
def torch_moe_align_block_size(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:

tokens_cnts = topk_ids.new_zeros((topk_ids.shape[0], num_experts))
tokens_cnts.scatter_(1, topk_ids, 1)
tokens_cnts = tokens_cnts.sum(dim=0)
cumsum = topk_ids.new_zeros(num_experts + 1)
for i in range(num_experts):
cumsum[i + 1] = cumsum[i] + ceil_div(tokens_cnts[i], block_size) * block_size
num_tokens_post_pad[0] = cumsum[-1]
expert_index = 0
for i in range(0, num_tokens_post_pad[0], block_size):
while i >= cumsum[expert_index + 1]:
expert_index += 1
experts_ids[i // block_size] = expert_index
numel = topk_ids.numel()
topk_ids = topk_ids.view(-1)
tokens_cnts = torch.zeros_like(tokens_cnts)
for i in range(numel):
expert_id = topk_ids[i]
rank_post_pad = tokens_cnts[expert_id] + cumsum[expert_id]
sorted_token_ids[rank_post_pad] = i
tokens_cnts[expert_id] += 1

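For illustration only (not part of this commit), a tiny CPU-only walk-through of the output contract, using ceil_div and the PyTorch reference above with made-up routing values:

# 3 tokens, each routed to its top-2 of 4 experts (made-up routing).
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 3]])
num_experts, block_size = 4, 4
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)

sorted_ids = torch.full((max_num_tokens_padded,), topk_ids.numel(), dtype=torch.int32)
expert_ids = torch.empty(ceil_div(max_num_tokens_padded, block_size), dtype=torch.int32)
num_tokens_post_pad = torch.empty(1, dtype=torch.int32)

torch_moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)

# Each expert's tokens are padded up to a multiple of block_size, so the 6 routed entries
# occupy 16 padded slots; untouched slots keep the sentinel 6 (= topk_ids.numel()).
print(num_tokens_post_pad)  # tensor([16], ...)
print(expert_ids[:4])       # tensor([0, 1, 2, 3], ...): one block per expert
print(sorted_ids[:16])      # entries 0,4 go to expert 0; 2 to expert 1; 1,3 to expert 2; 5 to expert 3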

def test():
def generate_unique_rows_tensor(rows=59, cols=6, max_val=63):
assert cols <= max_val + 1, "Number of columns cannot be greater than max_val + 1"

tensor = torch.empty((rows, cols), dtype=torch.int64, device="cuda")

for i in range(rows):
row = torch.randperm(max_val + 1, dtype=torch.int64, device="cuda")[:cols]
tensor[i] = row

return tensor

num_experts = 64
topk_ids = generate_unique_rows_tensor(8192, 6, num_experts - 1)
block_size = 16
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

sorted_ids_1 = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device)
sorted_ids_1.fill_(topk_ids.numel())
expert_ids_1 = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device)
num_tokens_post_pad_1 = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

import time

start_time = time.time()
torch_moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
torch.cuda.synchronize()
end_time = time.time()
print("torch cost: ", end_time - start_time)

start_time = time.time()
moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids_1, expert_ids_1, num_tokens_post_pad_1)
torch.cuda.synchronize()  # kernel launches are asynchronous; wait for them before reading the clock
end_time = time.time()
print("triton cost: ", end_time - start_time)
assert torch.equal(sorted_ids, sorted_ids_1)
assert torch.equal(expert_ids, expert_ids_1)
assert torch.equal(num_tokens_post_pad, num_tokens_post_pad_1)


if __name__ == "__main__":
test()