Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the incorrect permutation of gguf #31788

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/transformers/modeling_gguf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):

if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
num_heads = parsed_parameters["config"]["num_attention_heads"]
tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0])
weights = weights.reshape(tmp_shape)
weights = weights.transpose(0, 2, 1, 3)
weights = weights.reshape(shape[::-1])
n_head_kv = parsed_parameters["config"]["num_key_value_heads"]
PenutChen marked this conversation as resolved.
Show resolved Hide resolved
if ".attn_q." in name:
weights = reverse_hf_permute(weights, num_heads, num_heads)
elif ".attn_k." in name:
weights = reverse_hf_permute(weights, num_heads, n_head_kv)

for tensor_name in tensor_key_mapping:
if tensor_name in name:
Expand All @@ -163,3 +164,12 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

return parsed_parameters


def reverse_hf_permute(weights: np.ndarray, n_head: int, n_head_kv: int) -> np.ndarray:
PenutChen marked this conversation as resolved.
Show resolved Hide resolved
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv

dim = weights.shape[0] // n_head // 2
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
return w.swapaxes(2, 1).reshape(weights.shape)
SunMarc marked this conversation as resolved.
Show resolved Hide resolved