From da572980bc7152b272cd8f57ddc1edff326d184d Mon Sep 17 00:00:00 2001 From: Penut Chen Date: Thu, 4 Jul 2024 13:03:46 +0800 Subject: [PATCH 1/6] Fix the incorrect permutation of gguf --- .../modeling_gguf_pytorch_utils.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 1511fbac0976ac..6a476466ab4ef9 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -147,10 +147,11 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): if architecture == "llama" and (".attn_k." in name or ".attn_q." in name): num_heads = parsed_parameters["config"]["num_attention_heads"] - tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0]) - weights = weights.reshape(tmp_shape) - weights = weights.transpose(0, 2, 1, 3) - weights = weights.reshape(shape[::-1]) + n_head_kv = parsed_parameters["config"]["num_key_value_heads"] + if ".attn_q." in name: + weights = reverse_hf_permute(weights, num_heads, num_heads) + elif ".attn_k." in name: + weights = reverse_hf_permute(weights, num_heads, n_head_kv) for tensor_name in tensor_key_mapping: if tensor_name in name: @@ -163,3 +164,12 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}") return parsed_parameters + + +def reverse_hf_permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray: + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + + dim = weights.shape[0] // n_head // 2 + w = weights.reshape(n_head, dim, 2, *weights.shape[1:]) + return w.swapaxes(2, 1).reshape(weights.shape) From 9cfe762266315e158bceace49591842af84bcce6 Mon Sep 17 00:00:00 2001 From: Penut Chen <94501378+PenutChen@users.noreply.github.com> Date: Fri, 5 Jul 2024 08:48:07 +0800 Subject: [PATCH 2/6] rename num_kv_heads Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/modeling_gguf_pytorch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 6a476466ab4ef9..f10dc90454052c 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -147,7 +147,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): if architecture == "llama" and (".attn_k." in name or ".attn_q." in name): num_heads = parsed_parameters["config"]["num_attention_heads"] - n_head_kv = parsed_parameters["config"]["num_key_value_heads"] + num_kv_heads = parsed_parameters["config"]["num_key_value_heads"] if ".attn_q." in name: weights = reverse_hf_permute(weights, num_heads, num_heads) elif ".attn_k." in name: From fe3cd6c3c9b19c865dc6f7739119abac052e727e Mon Sep 17 00:00:00 2001 From: Penut Chen <94501378+PenutChen@users.noreply.github.com> Date: Fri, 5 Jul 2024 08:49:51 +0800 Subject: [PATCH 3/6] add typing to num_kv_heads Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/modeling_gguf_pytorch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index f10dc90454052c..d94a7f3fa8572b 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -166,7 +166,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): return parsed_parameters -def reverse_hf_permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray: +def reverse_hf_permute(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray: if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv From 07c4ccf3337ebe5e2f18a682ccc779c4c445ebd1 Mon Sep 17 00:00:00 2001 From: Penut Chen Date: Fri, 5 Jul 2024 09:11:32 +0800 Subject: [PATCH 4/6] rename variables --- src/transformers/modeling_gguf_pytorch_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index d94a7f3fa8572b..630a2d8c3eed12 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import numpy as np from tqdm import tqdm @@ -151,7 +153,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): if ".attn_q." in name: weights = reverse_hf_permute(weights, num_heads, num_heads) elif ".attn_k." in name: - weights = reverse_hf_permute(weights, num_heads, n_head_kv) + weights = reverse_hf_permute(weights, num_heads, num_kv_heads) for tensor_name in tensor_key_mapping: if tensor_name in name: @@ -167,8 +169,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): def reverse_hf_permute(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray: - if n_head_kv is not None and n_head != n_head_kv: - n_head = n_head_kv + if num_kv_heads is not None and n_head != num_kv_heads: + n_head = num_kv_heads dim = weights.shape[0] // n_head // 2 w = weights.reshape(n_head, dim, 2, *weights.shape[1:]) From d8651d360c47d087454242a0d792e0ea737183cc Mon Sep 17 00:00:00 2001 From: Penut Chen Date: Fri, 5 Jul 2024 09:14:07 +0800 Subject: [PATCH 5/6] refactor permute function name --- src/transformers/modeling_gguf_pytorch_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 630a2d8c3eed12..3cf34eab584e48 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -151,9 +151,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): num_heads = parsed_parameters["config"]["num_attention_heads"] num_kv_heads = parsed_parameters["config"]["num_key_value_heads"] if ".attn_q." in name: - weights = reverse_hf_permute(weights, num_heads, num_heads) + weights = reverse_permute_weights(weights, num_heads, num_heads) elif ".attn_k." in name: - weights = reverse_hf_permute(weights, num_heads, num_kv_heads) + weights = reverse_permute_weights(weights, num_heads, num_kv_heads) for tensor_name in tensor_key_mapping: if tensor_name in name: @@ -168,7 +168,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): return parsed_parameters -def reverse_hf_permute(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray: +def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray: + # Original permutation implementation + # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408 if num_kv_heads is not None and n_head != num_kv_heads: n_head = num_kv_heads From 450cd12231aa4f29a67b09fa66307ad44927bd60 Mon Sep 17 00:00:00 2001 From: Penut Chen Date: Fri, 5 Jul 2024 10:13:18 +0800 Subject: [PATCH 6/6] update the expected text of the llama3 q4 test --- tests/quantization/ggml/test_ggml.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index e5e8dbaf36cffb..db96e9052c5f36 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -188,8 +188,7 @@ def test_llama3_q4_0(self): text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) out = model.generate(**text, max_new_tokens=10) - EXPECTED_TEXT = "Hello, I am new to this forum. I am" - + EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) def test_tokenization_xnli(self):