Skip to content

Commit

Permalink
Merge branch 'main' into llama3_converter
Browse files Browse the repository at this point in the history
  • Loading branch information
TJ-Solergibert authored Jul 25, 2024
2 parents 0afd7b7 + 5f82f7a commit 43fe9b9
Show file tree
Hide file tree
Showing 20 changed files with 1,050 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/fa2_unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Instal nanotron
- name: Install nanotron
run: |
python -m pip install --upgrade pip
pip install packaging
Expand All @@ -55,4 +55,4 @@ jobs:
- name: Run tests
# NOTE: -m fa2 will only run the unit tests that have the mark
# "fa2" (these are FA2-related tests)
run: pytest -m fa2 --color=yes --durations=0 --ignore tests/fp8 --verbose tests/
run: pytest -m fa2 --color=yes --durations=0 --ignore tests/fp8 --ignore tests/nanoset --verbose tests/
15 changes: 15 additions & 0 deletions .github/workflows/trufflehog.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
on:
push:

name: Secret Leaks

jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ test:
--ignore tests/fp8 \
--verbose \
examples/doremi/tests/

pip install -r examples/llama/requirements.txt
pytest \
--color=yes \
--verbose \
examples/llama/tests/
4 changes: 4 additions & 0 deletions examples/doremi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,7 @@ For evaluation, we do uniform sampling on the test set to evaluate a 2.5B model
- 2.5B llama trained using the optimized weights: https://huggingface.co/nanotron/doremi-llama-2.5b-optimized-weights

and the dataset: https://huggingface.co/datasets/nanotron/the-pile-for-doremi

#### Thoughts

For DoReMi, it's useful if you don't initially have an idea of what would be a good distribution for your training data, or want a quick way to find a better baseline than the uniform distribution if you want to tune the data distribution by hand. In my previous experiments, DoReMi matched the pretraining performance of the distribution of mamba training but couldn't outperform it. I suspect it doesn't work well when there are nuances, meaning the difference between your known best distribution and a better distribution isn't significant.
17 changes: 17 additions & 0 deletions examples/llama/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## Debugging the tests with vscode

To debug the tests with vscode, add the following json to your `launch.json` file.

```
{
"name": "Test conversion",
"type": "python",
"request": "launch",
"module": "pytest",
"console": "integratedTerminal",
"args": [
"examples/llama/tests"
],
"justMyCode": false
}
```
Empty file added examples/llama/__init__.py
Empty file.
119 changes: 119 additions & 0 deletions examples/llama/convert_hf_to_nanotron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Converts a HF model to nanotron format
Command:
torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --checkpoint_path=hf_weights --save_path=nanotron_weights
"""

import dataclasses
import json
from argparse import ArgumentParser
from pathlib import Path

import nanotron
import torch
from convert_weights import get_config_mapping, get_weight_mapping, load_nanotron_model
from nanotron.config import LlamaConfig as NanotronLlamaConfig
from nanotron.models.llama import LlamaForTraining
from transformers import LlamaConfig as HFLlamaConfig
from transformers import LlamaForCausalLM


def _handle_attention_block(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, n_q_heads: int, n_kv_heads: int, d_qk: int
) -> torch.Tensor:
# Huggingface Llama separates the q, k, v weights (as opposed to nanotron).
# Furthermore, in the rotary embeddings in nanotron expects interleaved pairs of even
# and odd dimensions GPT-J style, while the huggingface implementation expects
# the whole 1st half and then the whole 2nd half GPT-NeoX style (for more information
# see flash_attn.layers.rotary.RotaryEmbedding).
# This function handles the concatenation of the q, k, v weights and proper permutation
# to ensure correct transformation.

def interleave(w: torch.Tensor):
w_new = []
for head_w in w.split(d_qk):
head_w = head_w.view(2, d_qk // 2, -1).transpose(0, 1).reshape(d_qk, -1)
w_new.append(head_w)
return torch.cat(w_new)

q = interleave(q)
k = interleave(k)
return torch.cat([q, k, v])


def convert_hf_to_nt(model_hf: LlamaForCausalLM, model_nt: LlamaForTraining, config: NanotronLlamaConfig):
"""Converts the weights from the model_hf to model_nt, making modifications
in-place."""

hf_sd = model_hf.state_dict()
nt_to_hf = get_weight_mapping(config, nt_to_hf=True)

for module_name_nt, module_nt in model_nt.named_modules():
for param_name_nt, param_nt in module_nt.named_parameters(recurse=False):
# In the case of qkv_proj, the nt_to_hf has exactly three keys, ccorresponding
# to q, k, v.
if "qkv_proj" in module_name_nt:
key_k, key_q, key_v = sorted(nt_to_hf[f"{module_name_nt}.{param_name_nt}"])
q = hf_sd[key_q]
k = hf_sd[key_k]
v = hf_sd[key_v]
param = _handle_attention_block(
q,
k,
v,
config.num_attention_heads,
config.num_key_value_heads,
config.hidden_size // config.num_attention_heads,
)
# The case of gate_up_proj, nt_to_hf_map has two keys.
elif "gate_up_proj" in module_name_nt:
key_gate, key_up = sorted(nt_to_hf[f"{module_name_nt}.{param_name_nt}"])
gate = hf_sd[key_gate]
up = hf_sd[key_up]
param = torch.cat([gate, up])
# All other cases are simple 1-to-1 correspondence.
else:
hf_key = nt_to_hf[f"{module_name_nt}.{param_name_nt}"]
param = hf_sd[hf_key]

with torch.no_grad():
param_nt.copy_(param)


def get_nanotron_config(config: HFLlamaConfig) -> NanotronLlamaConfig:
"""Converts a huggingface configuration to nanotron configuration."""
attrs = {key: getattr(config, value) for key, value in get_config_mapping(nt_to_hf=True).items()}
return NanotronLlamaConfig(**attrs)


def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path):
"""Loads the huggingface checkpoint in `checkpoint_path`, creates
a new nanotron instance, copies the weights from the huggingface checkpoint
and saves the transformed nanotron to `save_path`."""

# Load huggingface.
hf_model = LlamaForCausalLM.from_pretrained(checkpoint_path)

# Init nanotron model.
model_config = get_nanotron_config(hf_model.config)
nanotron_model = load_nanotron_model(model_config=model_config)

# Copy weights and save model.
parallel_context = nanotron.parallel.ParallelContext(
data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=1
)
convert_hf_to_nt(hf_model, nanotron_model, model_config)
nanotron.serialize.save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path)
with open(save_path / "model_config.json", "w+") as f:
json.dump(dataclasses.asdict(model_config), f)
print(f"Model saved to {save_path}")


if __name__ == "__main__":
parser = ArgumentParser(description="Convert HF weights to nanotron format")
parser.add_argument("--checkpoint_path", type=Path, default="llama-7b", help="Path to the checkpoint")
parser.add_argument("--save_path", type=Path, default="llama-7b-hf", help="Path to save the nanotron model")
args = parser.parse_args()

# Convert HF model to nanotron format.
convert_checkpoint_and_save(checkpoint_path=args.checkpoint_path, save_path=args.save_path)
154 changes: 154 additions & 0 deletions examples/llama/convert_nanotron_to_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Converts a nanotron model to HF format
Command:
torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=nanotron-path --save_path=hf-path
"""

import json
from argparse import ArgumentParser
from pathlib import Path
from typing import Literal, Optional

import torch
from convert_weights import get_config_mapping, get_weight_mapping, load_nanotron_model
from nanotron.config import LlamaConfig as NanotronLlamaConfig
from nanotron.models import init_on_device_and_dtype
from nanotron.models.llama import LlamaForTraining
from transformers import AutoTokenizer, LlamaForCausalLM
from transformers import LlamaConfig as HFLlamaConfig

TEST_PROMPT = "What is the meaning of the word chutzpah?\nThe word chutzpah means"


def _handle_attention_block(
qkv: torch.Tensor, part: Literal["q", "k", "v"], n_q_heads: int, n_kv_heads: int, d_qk: int
) -> torch.Tensor:
# Huggingface Llama separates the q, k, v weights (as opposed to nanotron).
# Furthermore, in the rotary embeddings in nanotron expects interleaved pairs of even
# and odd dimensions GPT-J style, while the huggingface implementation expects
# the whole 1st half and then the whole 2nd half GPT-NeoX style (for more information
# see flash_attn.layers.rotary.RotaryEmbedding).
# This function selects the proper chunk of the bundled qkv tensor and permutation
# to ensure correct transformation to huggingface.

def interleave(w: torch.Tensor):
w_new = []
for head_w in w.split(d_qk):
head_w = head_w.view(d_qk // 2, 2, -1).transpose(0, 1).reshape(d_qk, -1)
w_new.append(head_w)
return torch.cat(w_new)

assert part in ["q", "k", "v"], "part must be one of [q, k, v]"

index_end_q = n_q_heads * d_qk
index_end_k = index_end_q + n_kv_heads * d_qk
if part == "q":
return interleave(qkv[:index_end_q])
if part == "k":
return interleave(qkv[index_end_q:index_end_k])
return qkv[index_end_k:]


def _handle_gate_up_proj(gate_up_proj: torch.Tensor, gate: bool) -> torch.Tensor:
# The gate and up projection are bundled in nanotron.
# This function selects the proper chunk in the bundled weights to return
# either the gate or the up projection only.
weight_size = gate_up_proj.shape[0] // 2
if gate:
return gate_up_proj[:weight_size]
else:
return gate_up_proj[weight_size:]


def convert_nt_to_hf(nanotron_model: LlamaForTraining, hf_model: LlamaForCausalLM, model_config: NanotronLlamaConfig):
"""Converts the weights from the nanotron_model to hf_model, making modifications
in-place."""

nanotron_model_state_dict = nanotron_model.state_dict()

hf_to_nt = get_weight_mapping(model_config, nt_to_hf=False)
for module_name_hf, module_hf in hf_model.named_modules():
for param_name_hf, param_hf in module_hf.named_parameters(recurse=False):
# Get the Nanotron parameter
nanotron_key = hf_to_nt[f"{module_name_hf}.{param_name_hf}"]
param = nanotron_model_state_dict[nanotron_key]

if "qkv_proj" in nanotron_key:
proj_name = module_name_hf.split(".")[4][0]
param = _handle_attention_block(
param,
proj_name,
model_config.num_attention_heads,
model_config.num_key_value_heads,
model_config.hidden_size // model_config.num_attention_heads,
)

elif "gate_up_proj" in nanotron_key:
gate = "gate" in module_name_hf
param = _handle_gate_up_proj(param, gate)

with torch.no_grad():
param_hf.copy_(param)


def get_hf_config(config: NanotronLlamaConfig) -> HFLlamaConfig:
"""Converts a nanotron configuration to huggingface configuration."""
attrs = {key: getattr(config, value) for key, value in get_config_mapping(nt_to_hf=False).items()}
return HFLlamaConfig(**attrs)


def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str] = None):
"""Loads the nanotron checkpoint in `checkpoint_path`, creates
a new huggingface instance, copies the weights from the nanotron checkpoint
and saves the transformed huggingface to `save_path`."""

# Init nanotron model.
with open(checkpoint_path / "model_config.json", "r") as f:
attrs = json.load(f)
model_config = NanotronLlamaConfig(**attrs)
nanotron_model = load_nanotron_model(
model_config=model_config,
checkpoint_path=checkpoint_path,
)
# Init huggingface model.
with init_on_device_and_dtype(torch.device("cuda"), torch.bfloat16):
model_config_hf = get_hf_config(model_config)
hf_model = LlamaForCausalLM._from_config(model_config_hf)

# Copy weights, initialize tokenizer and save model.
if tokenizer_name is not None:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer.save_pretrained(save_path)
convert_nt_to_hf(nanotron_model, hf_model, model_config)
hf_model.save_pretrained(save_path)
print(f"Model saved to {save_path}")


def check_converted_model_generation(save_path: Path):
"""Loads a huggingface model and tokenizer from `save_path` and
performs a dummy text generation."""

tokenizer = AutoTokenizer.from_pretrained(save_path)
input_ids = tokenizer(TEST_PROMPT, return_tensors="pt")["input_ids"].cuda()
print("Inputs:", tokenizer.batch_decode(input_ids))

model = LlamaForCausalLM.from_pretrained(save_path).cuda().bfloat16()
out = model.generate(input_ids, max_new_tokens=100)
print("Generation (converted): ", tokenizer.batch_decode(out))


if __name__ == "__main__":
parser = ArgumentParser(description="Convert Nanotron weights to HF format")
parser.add_argument("--checkpoint_path", type=Path, default="llama-7b", help="Path to the checkpoint")
parser.add_argument("--save_path", type=Path, default="llama-7b-hf", help="Path to save the HF model")
parser.add_argument("--tokenizer_name", type=str, default="meta-llama/Llama-2-7b-chat-hf")
args = parser.parse_args()

# Convert Nanotron model to HF format.
convert_checkpoint_and_save(
checkpoint_path=args.checkpoint_path, save_path=args.save_path, tokenizer_name=args.tokenizer_name
)

# Check if the conversion was successful by generating some text.
if args.tokenizer_name is not None:
check_converted_model_generation(save_path=args.save_path)
Loading

0 comments on commit 43fe9b9

Please sign in to comment.