Adding SFT training #14

Open · wants to merge 10 commits into base: main
56 changes: 56 additions & 0 deletions docs/sft.md
@@ -0,0 +1,56 @@
# LlamaSFT
## Introduction
We have incorporated the ability to perform SFT in nanotron with the following features:
1. Packing multiple samples to fill the sequence length of the model
2. Training on completions only: The model learns from the answers, not from the user prompt & chat templates
3. Removing cross-attention between the multiple samples packed

In the following sections, we will delve into more detail about these features and how we have implemented them.

### Feature 1: Packing
To train the models efficiently, we pack multiple conversations into the same sample until the sequence length is filled. Since we are packing multiple sequences and want to avoid introducing padding tokens, [we flatten the batch dimension](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/trainer.py#L259), so that `sequence_length = micro_batch_size * sequence_length` and `micro_batch_size = 1`.
![](sft_feature1.png)
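A minimal sketch of the packing idea (illustrative only: the token ids and the toy sequence length are made up; the real logic lives in `ChatDataset.__iter__`):

```python
import numpy as np

sequence_length = 8  # toy value; the example config uses 4096
conversations = [[1, 2, 3], [4, 5, 6, 7], [8, 9]]  # already-tokenized conversations

packed, lengths = [], []
for tokens in conversations:
    if len(packed) + len(tokens) > sequence_length:  # buffer full, emit one packed sample
        break
    packed.extend(tokens)
    lengths.append(len(tokens))

# micro_batch_size = 1: the whole pack becomes a single row of shape (1, packed_length)
input_ids = np.array(packed, dtype=np.int32)[None, :]
print(input_ids.shape, lengths)  # (1, 7) [3, 4]
```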

### Feature 2: Training only on completions
Conversations consist of user messages, which are usually questions or inquiries, and the model's responses. The ultimate goal is for the model to improve the quality of its responses, not to learn the user questions or other artifacts such as the chat template. Therefore, during training we compute the loss only on the tokens that belong to the answers produced by the model.

To achieve this, when tokenizing the conversations, we will [store the role of each token](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_tokenizer.py#L59) and create a mask that the model will use in the loss computation [[1]](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L617), [[2]](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L603).
![](sft_feature2.png)
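To make the idea concrete, here is a hedged sketch of a completions-only loss (not the actual nanotron implementation, which lives in `llama_sft.py`; the tensor names and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# logits: (1, seq_len, vocab_size); labels: (1, seq_len)
# label_mask is True only on tokens produced by the assistant
logits = torch.randn(1, 6, 32)
labels = torch.randint(0, 32, (1, 6))
label_mask = torch.tensor([[False, False, True, True, True, False]])

per_token_loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction="none")
# Average only over completion tokens; prompt and chat-template tokens contribute nothing
loss = (per_token_loss * label_mask.flatten()).sum() / label_mask.sum()
print(loss)
```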

### Feature 3: Removing cross-attention
Finally, as we are packing multiple conversations together, we do not want the tokens of one conversation to attend to those of other conversations.
To do this, we will store the `position_ids` of each token in the packed sequence (see the sketch after the figure below) in order to:
1. Apply the RoPE embeddings correctly to each conversation
2. [Create the attention mask](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L346) needed by [`flash_attn_varlen_func`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L352) to compute the attention without cross-contamination between different conversations
![](sft_feature3.png)
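A rough sketch of how the packed `position_ids` translate into the `cu_seqlens` tensor that `flash_attn_varlen_func` expects; the actual attention call is only shown as a comment, since it requires the flash-attn package and real Q/K/V tensors:

```python
import torch

# position_ids restart at 0 for every conversation in the pack,
# e.g. two conversations of lengths 4 and 3 packed together:
position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2])

# Conversation boundaries are wherever position_ids == 0;
# cu_seqlens holds the cumulative start offsets plus the total length
starts = torch.nonzero(position_ids == 0, as_tuple=True)[0]
cu_seqlens = torch.cat([starts, torch.tensor([position_ids.numel()])]).to(torch.int32)
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
print(cu_seqlens, max_seqlen)  # tensor([0, 4, 7], dtype=torch.int32) 4

# flash_attn_varlen_func(q, k, v,
#                        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#                        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
#                        causal=True)
# attends within each conversation only, so there is no cross-contamination.
```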

## Internals
### Config file
For SFT, we need to set up the config file as follows:
```yaml
- data:
    dataset:
      hf_dataset: Magpie-Align/Magpie-Pro-300K-Filtered
      hf_dataset_split: train
      conversation_column_name: conversations
      train_on_completions_only: true
      remove_cross_attention: true
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Single dataset)
  start_training_step: 1
```
The `hf_dataset` should be a dataset from the HuggingFace Hub with the same structure as `Magpie-Align/Magpie-Pro-300K-Filtered`; that is, each conversation is a list of dictionaries, each with the keys `from` (either `gpt` or `human`) and `value`. We can select a split with `hf_dataset_split` and the dataset column with `conversation_column_name`. `train_on_completions_only` and `remove_cross_attention` toggle Features 2 and 3 on and off, but we will remove them for the final release.
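For reference, a single row of such a dataset looks roughly like this (shortened, illustrative values):

```python
sample = {
    "conversations": [
        {"from": "human", "value": "What is the capital of Switzerland?"},
        {"from": "gpt", "value": "The capital of Switzerland is Bern."},
    ]
}
```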

### Iterable Dataset
For SFT training, we have developed a new dataset, [`ChatDataset`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_dataset.py#L17), responsible for producing data batches during training. Unlike `Nanosets`, this new `ChatDataset` is an [`IterableDataset`](https://pytorch.org/docs/stable/data.html#iterable-style-datasets). The advantage of this type of dataset is that it does not require preprocessing the data before training: everything happens on the fly, which saves us the preprocessing step and the disk space occupied by the preprocessed data. The downside is that it is not trivial to recover the state of the DataLoader when restarting training. For this, we are developing a solution based on `torchdata`'s [`StatefulDataLoader`](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) that we will incorporate soon.

For now, we split the dataset across the different data parallel ranks; support for interleaved datasets is planned.
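The data parallel split relies on `datasets.distributed.split_dataset_by_node`, the same helper `ChatDataset` uses internally; a minimal sketch with illustrative rank values:

```python
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Streaming dataset, sharded so that each data parallel rank consumes a disjoint subset
dataset = load_dataset("Magpie-Align/Magpie-Pro-300K-Filtered", split="train", streaming=True)
dp_rank, dp_world_size = 0, 4  # illustrative; in nanotron these come from the parallel context
dataset = split_dataset_by_node(dataset, rank=dp_rank, world_size=dp_world_size)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
```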

### ChatTokenizer
To apply the chat template, tokenize the conversations, and store the role of each token, we have developed the [`ChatTokenizer`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_tokenizer.py#L6). Based on the one included in [`meta/llama3`](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), [this tokenizer returns](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_dataset.py#L92) the `tokens` of the conversation and a list of booleans, `is_completions`, indicating whether each token belongs to the model's responses, which is what Feature 2 needs.

For now, we only support the Llama3 tokenizer along with the official chat template of this model.
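Based on how `ChatDataset` calls it, usage looks roughly like the following (a sketch; the tokenizer path is illustrative):

```python
from nanotron.data.chat_tokenizer import ChatTokenizer

chat_tokenizer = ChatTokenizer("meta-llama/Meta-Llama-3.1-8B-Instruct")  # illustrative path
conversation = [
    {"from": "human", "value": "Hi!"},
    {"from": "gpt", "value": "Hello! How can I help you?"},
]
tokens, is_completions = chat_tokenizer(conversation)
# `tokens` holds the fully templated conversation; `is_completions` marks, token by token,
# whether a token belongs to an assistant turn, which is exactly what Feature 2 needs.
assert len(tokens) == len(is_completions)
```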

### Recover DataLoader
Pending development
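The plan (not merged yet) is to build on `torchdata`'s `StatefulDataLoader`, which exposes `state_dict()`/`load_state_dict()`; a rough sketch of how resumption could look, using a stand-in dataset:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

train_dataset = range(100)  # stand-in for the ChatDataset instance built in run_train.py

dataloader = StatefulDataLoader(train_dataset, batch_size=1, num_workers=0)
it = iter(dataloader)
for _ in range(10):  # ... train for a while ...
    next(it)

# Checkpoint the dataloader state next to the model/optimizer state
dataloader_state = dataloader.state_dict()

# On restart, rebuild the dataloader and restore where the iteration left off
resumed = StatefulDataLoader(train_dataset, batch_size=1, num_workers=0)
resumed.load_state_dict(dataloader_state)
```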
Binary file added docs/sft_feature1.png
Binary file added docs/sft_feature2.png
Binary file added docs/sft_feature3.png
98 changes: 98 additions & 0 deletions examples/config_llama8b_sft.yaml
@@ -0,0 +1,98 @@
checkpoints:
  checkpoint_interval: 1000
  checkpoints_path: checkpoints/
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  save_initial_state: false
data_stages:
- data:
    dataset:
      hf_dataset: Magpie-Align/Magpie-Pro-300K-Filtered
      hf_dataset_split: train
      conversation_column_name: conversations
      train_on_completions_only: true
      remove_cross_attention: true
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Single dataset)
  start_training_step: 1
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: SFT-Todi
  run: Llama3-8B
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    path: /store/swissai/a06/models/nanotron_checkpoints/Meta-Llama-3.1-8B-Instruct
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 131072
    num_hidden_layers: 32
    num_attention_heads: 32
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rope_interleaved: false
    rope_theta: 500000.0
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 2.0e-5
    lr_decay_starting_step: null
    lr_decay_steps: 98
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 4
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 4
  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: /store/swissai/a06/models/nanotron_checkpoints/Meta-Llama-3.1-8B-Instruct
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 4
  sequence_length: 4096
  train_steps: 250
  val_check_interval: -1
1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
    "safetensors",
    "dacite",
    "tqdm",
+    "wandb",
]

[tool.setuptools.packages.find]
29 changes: 27 additions & 2 deletions run_train.py
@@ -12,8 +12,9 @@

import numpy as np
from nanotron import logging
-from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
-from nanotron.data.dataloader_builder import build_nanoset_dataloader
+from nanotron.config import ChatDatasetsArgs, DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
+from nanotron.data.chat_dataset import ChatDataset
+from nanotron.data.dataloader_builder import build_chat_dataloader, build_nanoset_dataloader
from nanotron.dataloader import (
    clm_process,
    dummy_infinite_data_generator,
@@ -171,6 +172,30 @@ def get_dataloader_from_data_stage(
            dataloader_drop_last=True,
        )

        return train_dataloader
    # Case 4: Chat Datasets
    elif isinstance(data.dataset, ChatDatasetsArgs):
        with main_rank_first(trainer.parallel_context.world_pg):
            train_dataset = ChatDataset(
                dataset_path=data.dataset.hf_dataset,
                tokenizer_name_or_path=trainer.config.tokenizer.tokenizer_name_or_path,
                sequence_length=trainer.sequence_length,
                train_on_completions_only=data.dataset.train_on_completions_only,
                remove_cross_attention=data.dataset.remove_cross_attention,
                split=data.dataset.hf_dataset_split,
                conversation_column_name=data.dataset.conversation_column_name,
                dp_rank=trainer.parallel_context.dp_pg.rank(),
                dp_ranks_size=trainer.parallel_context.dp_pg.size(),
            )

        # Prepare dataloader
        train_dataloader = build_chat_dataloader(
            dataset=train_dataset,
            parallel_context=trainer.parallel_context,
            input_pp_rank=input_pp_rank,
            output_pp_rank=output_pp_rank,
        )

        return train_dataloader
    else:
        raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")
18 changes: 17 additions & 1 deletion src/nanotron/config/config.py
@@ -107,11 +107,27 @@ def __post_init__(self):
            self.dataset_weights = list(tmp_dataset_folder.values())


@dataclass
class ChatDatasetsArgs:
    hf_dataset: str
    hf_dataset_split: str
    conversation_column_name: str
    # Debug
    train_on_completions_only: bool = True
    remove_cross_attention: bool = True

    def __post_init__(self):
        if self.hf_dataset_split is None:
            self.hf_dataset_split = "train"
        if self.conversation_column_name is None:
            self.conversation_column_name = "conversations"


@dataclass
class DataArgs:
    """Arguments related to the data and data files processing"""

-    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
+    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs]
    seed: Optional[int]
    num_loading_workers: Optional[int] = 1
134 changes: 134 additions & 0 deletions src/nanotron/data/chat_dataset.py
@@ -0,0 +1,134 @@
from typing import List

import numpy as np
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from nanotron.data.chat_tokenizer import ChatTokenizer
from nanotron.data.collator import (
    build_labels,
    build_labels_completions_only,
    build_position_ids,
    build_position_ids_dummy,
)
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer


class ChatDataset(IterableDataset):
    """
    Chat Dataset for training models with:
      1. Padding-Free Packing
      2. No cross-contamination between packed samples
      3. Train on completions only

    Args:
        dataset_path (str): Path to the dataset in the file system. If provided, data will be loaded from this path instead of downloaded.
        tokenizer_name_or_path (str): Path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.
        sequence_length (int): max sequence length
        train_on_completions_only (bool): Whether to just train on completions or not. To be deleted
        remove_cross_attention (bool): Whether to just attend to the tokens from the same sample or to all (Vanilla mechanism). To be deleted
        split (str): Split of the dataset to train on
        conversation_column_name (str): Column name of the dataset containing the conversations
        dp_rank (int): rank of the current data parallel process
        dp_ranks_size (int): number of data parallel processes participating in training
    """

    def __init__(
        self,
        dataset_path: str,
        tokenizer_name_or_path,
        sequence_length: int,
        conversation_column_name: str,
        train_on_completions_only: bool = True,
        remove_cross_attention: bool = True,
        split: str = "train",
        dp_rank: int = 0,
        dp_ranks_size: int = 1,
        skip_num_samples: int = None,  # TODO(tj.solergibert) Delete, check later comment
        seed: int = 1234,
    ) -> None:

        # WARN(tj.solergibert) Currently we DON'T support recovering training from an interruption. Check the following TODOs
        # TODO(tj.solergibert) Support checkpointing for resuming training. We have to store the number of consumed samples from the dataset (which is different from the number of steps) and the BUFFERS.
        #   skip_num_samples will fail, as it's computed with the number of steps and as we are packing sequences we might have consumed MORE samples from the dataset
        # TODO(tj.solergibert) Support interleaving datasets

        self.dataset_path = dataset_path
        self.chat_tokenizer = ChatTokenizer(tokenizer_name_or_path)
        self.sequence_length = sequence_length
        self.conversation_column_name = conversation_column_name
        self.skip_num_samples = skip_num_samples
        self.seed = seed

        # Load, split and shuffle dataset
        self.dataset = load_dataset(dataset_path, split=split, streaming=True)
        self.dataset = split_dataset_by_node(self.dataset, dp_rank, dp_ranks_size)
        self.dataset = self.dataset.shuffle(seed=seed, buffer_size=10_000)

        # TODO(tj.solergibert) Delete (debug), just for switching the train-on-completions-only setting
        if train_on_completions_only:
            self.create_labels = build_labels_completions_only
        else:
            self.create_labels = build_labels

        # TODO Delete (debug), just for switching the remove cross-attention setting
        if remove_cross_attention:
            self.create_position_ids = build_position_ids
        else:
            self.create_position_ids = build_position_ids_dummy

        # TODO(tj.solergibert) Delete (debug)
        self.debug_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)  # TODO delete debug
        self.debug_tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['from'] + '<|end_header_id|>\n\n'+ message['value'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>' }}{% endif %}"

    def __iter__(self):
        max_buffer_token_len = 1 + self.sequence_length
        buffer_tokens: List[int] = []
        buffer_is_completition: List[int] = []
        buffer_lengths: List[int] = []

        while True:
            for sample in iter(self.dataset):
                tokens, is_completition = self.chat_tokenizer(sample[self.conversation_column_name])

                # TODO(tj.solergibert) Delete (debug). Check if HF apply_chat_template produces the same result as ChatTokenizer
                #   The [:-1] of tokens is because apply_chat_template doesn't add the eos (NOT eot) token
                assert (
                    self.debug_tokenizer.apply_chat_template(sample["conversations"]) == tokens[:-1]
                ), f'{self.debug_tokenizer.apply_chat_template(sample["conversations"])}\n\n{tokens[:-1]}'

                buffer_tokens.extend(tokens)
                buffer_is_completition.extend(is_completition)
                buffer_lengths.append(len(tokens))

                if len(buffer_tokens) > max_buffer_token_len:  # Can't pack more samples, yield
                    # Pop last sample from buffers
                    sample_tokens = buffer_tokens[: -len(tokens)]
                    sample_completitions = buffer_is_completition[: -len(tokens)]
                    sample_lengths = buffer_lengths[:-1]

                    # TODO(tj.solergibert) Delete (debug)
                    assert len(sample_tokens) == len(sample_completitions) == sum(sample_lengths)

                    # Reset tokens buffers
                    buffer_tokens = tokens.copy()
                    buffer_is_completition = is_completition.copy()
                    buffer_lengths = [len(tokens)]

                    # TODO(tj.solergibert) Delete (debug), just for switching the train-on-completions-only setting
                    sample_completitions = self.create_labels(sample_tokens, sample_completitions)

                    # TODO(tj.solergibert) Delete (debug), just for switching the remove cross-attention setting
                    position_ids = self.create_position_ids(sample_lengths, self.sequence_length)

                    # TODO(tj.solergibert) Delete (debug)
                    # assert len(sample_tokens) <= max_buffer_token_len

                    yield {
                        "input_ids": np.array(sample_tokens, dtype=np.int32),
                        "is_completitions": np.array(sample_completitions, dtype=np.bool_),
                        "position_ids": position_ids,
                    }

            # TODO(tj.solergibert) Change for log_rank (log_rank is problematic with JupyterNB)
            print("Consumed all samples, dataset is being re-looped.")