Skip to content

Commit

Permalink
[Tools] Add Baichuan1/2 convert tool (#451)
Browse files Browse the repository at this point in the history
  • Loading branch information
abenmao committed Jun 17, 2024
1 parent 80df391 commit 2331613
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ xFasterTransformer provides a series of APIs, both of C++ and Python, for end us
| Llama | ✔ | ✔ | ✔ |
| Llama2 | ✔ | ✔ | ✔ |
| Llama3 | ✔ | ✔ | ✔ |
| Baichuan1 | ✔ | ✔ | ✔ |
| Baichuan | ✔ | ✔ | ✔ |
| Baichuan2 | ✔ | ✔ | ✔ |
| QWen | ✔ | ✔ | ✔ |
| QWen2 | ✔ | ✔ | ✔ |
Expand Down Expand Up @@ -174,6 +174,7 @@ xFasterTransformer supports a different model format from Huggingface, but it's
- ChatGLM3Convert
- OPTConvert
- BaichuanConvert
- Baichuan2Convert
- QwenConvert
- Qwen2Convert
- DeepseekConvert
Expand Down
2 changes: 2 additions & 0 deletions src/xfastertransformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def with_mpirun():
"ChatGLM3Convert",
"OPTConvert",
"BaichuanConvert",
"Baichuan2Convert",
"QwenConvert",
"Qwen2Convert",
"YaRNLlamaConvert",
Expand All @@ -59,6 +60,7 @@ def with_mpirun():
from .tools import ChatGLM3Convert
from .tools import OPTConvert
from .tools import BaichuanConvert
from .tools import Baichuan2Convert
from .tools import QwenConvert
from .tools import Qwen2Convert
from .tools import YaRNLlamaConvert
Expand Down
1 change: 1 addition & 0 deletions src/xfastertransformer/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .chatglm3_convert import ChatGLM3Convert
from .opt_convert import OPTConvert
from .baichuan_convert import BaichuanConvert
from .baichuan2_convert import Baichuan2Convert
from .qwen_convert import QwenConvert
from .qwen2_convert import Qwen2Convert
from .yarn_llama_convert import YaRNLlamaConvert
29 changes: 29 additions & 0 deletions src/xfastertransformer/tools/baichuan2_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .baichuan_convert import BaichuanConvert

from torch import nn

class Baichuan2Convert(BaichuanConvert):
    """
    Convert a huggingface Baichuan2 model (https://huggingface.co/baichuan-inc)
    to the xFasterTransformer format.

    Identical to the Baichuan1 conversion except for the output head:
    Baichuan2 uses a "NormHead", so the lm_head weight must be L2-normalized
    before it is saved.
    """

    # NOTE: no __init__ override — the base class setup is sufficient, so the
    # redundant pass-through constructor was removed.

    def _head_process(self, param):
        """L2-normalize the lm_head weight (NormHead).

        Overrides the identity hook in BaichuanConvert. normalize() defaults
        to p=2, dim=1, i.e. each row (one vocabulary entry) is scaled to unit
        norm — matching Baichuan2's NormHead inference behavior.
        """
        return nn.functional.normalize(param)
6 changes: 4 additions & 2 deletions src/xfastertransformer/tools/baichuan_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import multiprocessing
import numpy as np
import os
from torch import nn

from transformers import AutoModelForCausalLM

Expand All @@ -31,6 +30,9 @@ class BaichuanConvert(BaseModelConvert):
def __init__(self):
    # No Baichuan-specific state to set up; delegate to the
    # BaseModelConvert constructor for the shared converter setup.
    super().__init__()

def _head_process(self, param):
    """Hook for post-processing the lm_head weight before saving.

    Identity for Baichuan1. Subclasses override this to alter the head
    weights — Baichuan2Convert normalizes it to implement NormHead. Called
    from split_and_convert() on parameters whose name contains "lm_head".
    """
    return param

def split_and_convert_process(self, i, saved_dir, factor, key, val):
def save_val(val, key, tp_num=None):
if key.startswith("model."):
Expand Down Expand Up @@ -148,7 +150,7 @@ def split_and_convert(self, input_dir, output_dir, dtype, processes):
if "embed" in name:
model_named_parameters[name] = param
elif "lm_head" in name:
model_named_parameters[name] = nn.functional.normalize(param)
model_named_parameters[name] = self._head_process(param)
else:
model_named_parameters[name] = param.permute(1, 0) if len(param.shape) == 2 else param

Expand Down

0 comments on commit 2331613

Please sign in to comment.