Commit
First prototype, let's jump padding free
TJ-Solergibert committed Jul 27, 2024
1 parent 50da275 commit 419b33a
Showing 10 changed files with 2,141 additions and 11 deletions.
764 changes: 764 additions & 0 deletions convert_hf_nanotron.ipynb

Large diffs are not rendered by default.

97 changes: 97 additions & 0 deletions examples/config_llama_sft.yaml
@@ -0,0 +1,97 @@
checkpoints:
  checkpoint_interval: 1000
  checkpoints_path: /mloscratch/homes/solergib/converter/nanotron/checkpoints
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  save_initial_state: false
data_stages:
- data:
    dataset:
      hf_dataset: Open-Orca/SlimOrca
      hf_dataset_split: train
      conversation_column_name: conversations
      train_on_completions_only: true
      remove_cross_attention: true
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Single dataset)
  start_training_step: 1
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: Chat
  run: Llama3-8B
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 4096
    num_attention_heads: 32
    num_hidden_layers: 4
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    rope_theta: 500000.0
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 98
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 1
  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 3
  sequence_length: 4096
  train_steps: 100
  val_check_interval: -1
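For reference, a minimal sketch of how this config would be consumed on the Python side before launching training through run_train.py. It assumes nanotron's existing get_config_from_file helper and the ChatDatasetsArgs dataclass added below in src/nanotron/config/config.py; the helper call and attribute paths are assumptions, not part of this diff:

# Hedged sketch: load the SFT config and inspect the new chat-dataset stage.
# get_config_from_file is assumed to be nanotron's existing YAML loader (not part of this commit).
from nanotron.config import get_config_from_file

config = get_config_from_file("examples/config_llama_sft.yaml")

stage = config.data_stages[0]
print(stage.data.dataset.hf_dataset)                 # Open-Orca/SlimOrca
print(stage.data.dataset.train_on_completions_only)  # True
print(stage.data.dataset.remove_cross_attention)     # True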
30 changes: 28 additions & 2 deletions run_train.py
@@ -12,8 +12,9 @@

import numpy as np
from nanotron import logging
from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
from nanotron.data.dataloader_builder import build_nanoset_dataloader
from nanotron.config import ChatDatasetsArgs, DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
from nanotron.data.chat_dataset import ChatDataset
from nanotron.data.dataloader_builder import build_chat_dataloader, build_nanoset_dataloader
from nanotron.dataloader import (
    clm_process,
    dummy_infinite_data_generator,
@@ -171,6 +172,31 @@ def get_dataloader_from_data_stage(
            dataloader_drop_last=True,
        )

        return train_dataloader
    # Case 4: Chat Datasets
    elif isinstance(data.dataset, ChatDatasetsArgs):
        with main_rank_first(trainer.parallel_context.world_pg):
            train_dataset = ChatDataset(
                dataset_path=data.dataset.hf_dataset,
                tokenizer_name_or_path=trainer.config.tokenizer.tokenizer_name_or_path,
                sequence_length=trainer.sequence_length,
                train_on_completions_only=data.dataset.train_on_completions_only,
                remove_cross_attention=data.dataset.remove_cross_attention,
                split=data.dataset.hf_dataset_split,
                conversation_column_name=data.dataset.conversation_column_name,
                dp_rank=trainer.parallel_context.dp_pg.rank(),
                dp_ranks_size=trainer.parallel_context.dp_pg.size(),
            )

        # Prepare dataloader
        train_dataloader = build_chat_dataloader(
            dataset=train_dataset,
            sequence_length=trainer.sequence_length,
            parallel_context=trainer.parallel_context,
            input_pp_rank=input_pp_rank,
            output_pp_rank=output_pp_rank,
        )

        return train_dataloader
    else:
        raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")
18 changes: 17 additions & 1 deletion src/nanotron/config/config.py
@@ -107,11 +107,27 @@ def __post_init__(self):
            self.dataset_weights = list(tmp_dataset_folder.values())


@dataclass
class ChatDatasetsArgs:
    hf_dataset: str
    hf_dataset_split: str
    conversation_column_name: str
    # Debug
    train_on_completions_only: bool = True
    remove_cross_attention: bool = True

    def __post_init__(self):
        if self.hf_dataset_split is None:
            self.hf_dataset_split = "train"
        if self.conversation_column_name is None:
            self.conversation_column_name = "conversations"


@dataclass
class DataArgs:
    """Arguments related to the data and data files processing"""

    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs]
    seed: Optional[int]
    num_loading_workers: Optional[int] = 1

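The dataset block in examples/config_llama_sft.yaml maps one-to-one onto the new ChatDatasetsArgs. As a minimal sketch, the programmatic equivalent of that YAML block, using only the fields defined above:

# Sketch: direct construction mirroring the `dataset:` block of examples/config_llama_sft.yaml.
from nanotron.config import ChatDatasetsArgs, DataArgs

chat_dataset = ChatDatasetsArgs(
    hf_dataset="Open-Orca/SlimOrca",
    hf_dataset_split="train",
    conversation_column_name="conversations",
    train_on_completions_only=True,
    remove_cross_attention=True,
)
data = DataArgs(dataset=chat_dataset, seed=42, num_loading_workers=1)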
139 changes: 139 additions & 0 deletions src/nanotron/data/chat_dataset.py
@@ -0,0 +1,139 @@
from typing import List

import numpy as np
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from nanotron.data.chat_tokenizer import ChatTokenizer
from nanotron.data.collator import (
    build_labels,
    build_labels_completions_only,
    build_position_ids,
    build_position_ids_dummy,
)
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer


class ChatDataset(IterableDataset):
    """
    Chat Dataset for training models with:
      1. Packing
      2. No cross-contamination between packed samples
      3. Training on completions only
    Args:
        dataset_path (str): Path to the dataset in the file system or id of a dataset hosted on the Hugging Face Hub.
        tokenizer_name_or_path (str): Path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.
        sequence_length (int): Max sequence length
        train_on_completions_only (bool): Whether to train only on the assistant completions. To be deleted
        remove_cross_attention (bool): Whether tokens attend only to tokens from the same sample or to all tokens (vanilla mechanism). To be deleted
        split (str): Split of the dataset to train on
        conversation_column_name (str): Column name of the dataset containing the conversations
        dp_rank (int): Rank of the current data parallel process
        dp_ranks_size (int): Number of data parallel processes participating in training
    """

    def __init__(
        self,
        dataset_path: str,
        tokenizer_name_or_path,
        sequence_length: int,
        conversation_column_name: str,
        train_on_completions_only: bool = True,
        remove_cross_attention: bool = True,
        split: str = "train",
        dp_rank: int = 0,
        dp_ranks_size: int = 1,
        skip_num_samples: int = None,  # TODO Delete, check later comment
        seed: int = 1234,
    ) -> None:

        # TODO: Support checkpointing for resuming training. We have to store the number of consumed samples from the dataset (which is different from the number of steps) and the buffers.
        # skip_num_samples will fail, as it's computed from the number of steps, and since we are packing sequences we might have consumed MORE samples from the dataset
        # TODO: Support interleaving datasets

        self.dataset_path = dataset_path
        self.chat_tokenizer = ChatTokenizer(tokenizer_name_or_path)
        self.sequence_length = sequence_length
        self.conversation_column_name = conversation_column_name
        self.skip_num_samples = skip_num_samples
        self.seed = seed

        # Load, split and shuffle dataset. Also skip samples if resuming training.
        self.dataset = load_dataset(dataset_path, split=split, streaming=True)
        self.dataset = split_dataset_by_node(self.dataset, dp_rank, dp_ranks_size)
        self.dataset = self.dataset.shuffle(seed=seed, buffer_size=10_000)

        # TODO delete, just for switching the train-on-completions-only setting
        if train_on_completions_only:
            self.create_labels = build_labels_completions_only
        else:
            self.create_labels = build_labels

        # TODO delete, just for switching the remove-cross-attention setting
        if remove_cross_attention:
            self.create_position_ids = build_position_ids
        else:
            self.create_position_ids = build_position_ids_dummy

        # TODO delete (debug), just change the dict keys
        self.debug_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)  # TODO delete debug
        self.debug_tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['from'] + '<|end_header_id|>\n\n'+ message['value'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>' }}{% endif %}"

    def __iter__(self):
        max_buffer_token_len = 1 + self.sequence_length
        buffer_tokens: List[int] = []
        buffer_is_completition: List[int] = []
        buffer_lengths: List[int] = []

        while True:
            for sample in iter(self.dataset):
                tokens, is_completition = self.chat_tokenizer(sample[self.conversation_column_name])

                # TODO assert that tokenized conversations are not longer than max_buffer_token_len?

                # TODO delete (debug). The [:-1] of tokens is because apply_chat_template doesn't add the eos (NOT eot) token
                assert (
                    self.debug_tokenizer.apply_chat_template(sample["conversations"]) == tokens[:-1]
                ), f'{self.debug_tokenizer.apply_chat_template(sample["conversations"])}\n\n{tokens[:-1]}'

                buffer_tokens.extend(tokens)
                buffer_is_completition.extend(is_completition)
                buffer_lengths.append(len(tokens))

                if len(buffer_tokens) > max_buffer_token_len:  # Can't pack more samples, yield
                    # Pop last sample from buffers
                    sample_tokens = buffer_tokens[: -len(tokens)]
                    sample_completitions = buffer_is_completition[: -len(tokens)]
                    sample_lengths = buffer_lengths[:-1]

                    # TODO delete (debug)
                    assert len(sample_tokens) == len(sample_completitions) == sum(sample_lengths)

                    # Reset token buffers
                    buffer_tokens = tokens.copy()
                    buffer_is_completition = is_completition.copy()
                    buffer_lengths = [len(tokens)]

                    # Pad to max_buffer_token_len. Pad token added in ChatTokenizer init if necessary
                    sample_tokens.extend(
                        [self.chat_tokenizer.tokenizer.pad_token_id] * (max_buffer_token_len - len(sample_tokens))
                    )
                    sample_completitions.extend([False] * (max_buffer_token_len - len(sample_completitions)))

                    # TODO delete, just for switching the train-on-completions-only setting
                    labels = self.create_labels(sample_tokens, sample_completitions)

                    # TODO delete, just for switching the remove-cross-attention setting
                    position_ids = self.create_position_ids(sample_lengths, self.sequence_length)

                    # TODO delete (debug)
                    assert len(sample_tokens) == max_buffer_token_len

                    yield {
                        "input_ids": np.array(sample_tokens[:-1], dtype=np.int32),
                        "label_ids": labels,
                        "position_ids": position_ids,
                    }

            print("Consumed all samples, dataset is being re-looped.")
81 changes: 81 additions & 0 deletions src/nanotron/data/chat_tokenizer.py
@@ -0,0 +1,81 @@
from typing import List, Tuple

from transformers import AutoTokenizer


class ChatTokenizer:
    """
    The ChatTokenizer encodes a conversation applying the Llama3 chat template and returns, for each token, whether it belongs to the assistant's answer or not.
    Args:
        tokenizer_name_or_path (str): A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.
    """

    def __init__(self, tokenizer_name_or_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        # Add pad token if necessary
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "<|eot_id|>"})

    def __call__(self, conversation: List[dict]) -> Tuple[List[int], List[bool]]:
        """
        Applies the Llama3 chat template, encodes the conversation and returns the tokens along with a boolean per token indicating whether the token belongs to the assistant's answer, so that training can be restricted to the assistant's answers.
        Args:
            conversation (List[dict]): List of dicts where each dict contains the "from" key to specify the sender of the message and the "value" key with the message.
                Same format as the SlimOrca dataset, with possible "from" values: "system", "human" and "gpt"
        Example:
            conversation: [ { "from": "system", "value": "You are an AI assistant that follows instruction extremely well. Help as much as you can."},
                            { "from": "human", "value": "Answer the following question: - number is 54 - debutteam is pittsburgh steelers - draftpick is 166 - birth date is 24 may 1982 - weight is 243 - nfl is wal475737 - debutyear is 2005 - finalteam is new york sentinels - statlabel is tackles sacks interceptions - heightin is 3 - statvalue is 9 0.0 1 - heightft is 6 - college is temple - birth place is pottstown , pennsylvania - draftyear is 2005 - position is linebacker - draftround is 5 - finalyear is 2009 Given the details above, guess who could this information be about.\nAnswer:"},
                            { "from": "gpt", "value": "The information provided seems to refer to Rian Wallace, a former NFL player."} ]
            After applying the chat template:
                <|begin_of_text|><|start_header_id|>system<|end_header_id|>
                You are an AI assistant that follows instruction extremely well. Help as much as you can.<|eot_id|><|start_header_id|>human<|end_header_id|>
                Answer the following question: - number is 54 - debutteam is pittsburgh steelers - draftpick is 166 - birth date is 24 may 1982 - weight is 243 - nfl is wal475737 - debutyear is 2005 - finalteam is new york sentinels - statlabel is tackles sacks interceptions - heightin is 3 - statvalue is 9 0.0 1 - heightft is 6 - college is temple - birth place is pottstown , pennsylvania - draftyear is 2005 - position is linebacker - draftround is 5 - finalyear is 2009 Given the details above, guess who could this information be about.
                Answer:<|eot_id|><|start_header_id|>gpt<|end_header_id|>
                The information provided seems to refer to Rian Wallace, a former NFL player.<|eot_id|>
        Returns:
            tokens (List[int]): A list of tokens, e.g. [128000, 128006, 9125, 128007, 271, 2675, 527, ..., 12873, 2851, 13, 128009, 128001]
            is_completitions (List[bool]): A list of booleans indicating whether each token belongs to the assistant's answer or not, e.g. [False, False, False, ..., False, True, True, True, True]
        """
        tokens = []
        # Append <|begin_of_text|>
        tokens.append(self.tokenizer.bos_token_id)
        is_completitions = [False] * len(tokens)

        for message in conversation:
            message_tokens, message_completitions = self.encode_message(message)
            tokens.extend(message_tokens)
            is_completitions.extend(message_completitions)

        # Append <|end_of_text|> token
        tokens.extend(self.tokenizer.encode("<|end_of_text|>", add_special_tokens=False))
        is_completitions.append(True)

        return tokens, is_completitions

    def encode_message(self, message: dict) -> Tuple[List[int], List[bool]]:
        # TODO The "from", "value", "gpt" keys are from the SlimOrca dataset. Llama3 uses different ones. We should stick to a
        # single format and document it properly rather than supporting multiple formats, as each one will need a different
        # ChatTokenizer and the idea is that all datasets share the same ChatTokenizer

        # Encode header
        tokens = self.tokenizer.encode(
            f"<|start_header_id|>{message['from']}<|end_header_id|>\n\n", add_special_tokens=False
        )
        is_completitions = [False] * len(tokens)

        # Encode message
        tokens.extend(self.tokenizer.encode(message["value"].strip(), add_special_tokens=False))

        # Append <|eot_id|> token
        tokens.extend(self.tokenizer.encode("<|eot_id|>", add_special_tokens=False))

        # True if token belongs to assistant answer, False otherwise
        is_completitions.extend([True if message["from"] == "gpt" else False] * (len(tokens) - len(is_completitions)))

        return tokens, is_completitions
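A short usage sketch of the class above, with a made-up SlimOrca-style conversation (requires access to the Llama 3 tokenizer; the comments describe expected behaviour rather than captured output):

from nanotron.data.chat_tokenizer import ChatTokenizer

chat_tokenizer = ChatTokenizer("meta-llama/Meta-Llama-3-8B-Instruct")

conversation = [
    {"from": "system", "value": "You are a helpful assistant."},
    {"from": "human", "value": "Who was Rian Wallace?"},
    {"from": "gpt", "value": "Rian Wallace is a former NFL linebacker who played for the Pittsburgh Steelers."},
]

tokens, is_completitions = chat_tokenizer(conversation)

# One flag per token: only the body of the "gpt" turn, its <|eot_id|> and the final
# <|end_of_text|> are True, which is what train_on_completions_only relies on.
assert len(tokens) == len(is_completitions)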
