diff --git a/docs/sft.md b/docs/sft.md
new file mode 100644
index 00000000..e3d104c6
--- /dev/null
+++ b/docs/sft.md
@@ -0,0 +1,56 @@
+# LlamaSFT
+## Introduction
+We have incorporated the ability to perform SFT in nanotron with the following features:
+1. Packing multiple samples to fill the sequence length of the model
+2. Training on completions only: The model learns from the answers, not from the user prompts & chat template
+3. Removing cross-attention between the packed samples
+
+In the following sections, we will delve into more detail about these features and how we have implemented them.
+
+### Feature 1: Packing
+To train the models efficiently, we will pack multiple conversations into the same sample until the sequence length is filled. As we pack multiple sequences and want to avoid introducing padding tokens, [we will flatten the batch size](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/trainer.py#L259), so that `sequence_length = micro_batch_size * sequence_length` and `micro_batch_size = 1`.
+![](sft_feature1.png)
+
+### Feature 2: Training only on completions
+Conversations consist of user messages, which are usually questions or inquiries, and the model's responses. The ultimate goal is for the model to improve the quality of its responses, not so much to learn about user questions or other aspects like the chat template. Therefore, during training, we will compute the loss only on the tokens that belong to the answers produced by the model.
+
+To achieve this, when tokenizing the conversations, we will [store the role of each token](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_tokenizer.py#L59) and create an attention mask that the model will use in the loss computation [[1]](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L617), [[2]](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L603).
+![](sft_feature2.png)
+
+### Feature 3: Removing cross-attention
+Finally, as we are packing multiple conversations together, we do not want the tokens of one conversation to attend to those of other conversations.
+To do this, we will store the `position_ids` of each token in the sequence in order to:
+1. Apply the RoPE embeddings correctly to each conversation
+2. [Create the attention mask](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L346) needed by [`flash_attn_varlen_func`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L352) to compute the attention without cross-contamination between different conversations
+![](sft_feature3.png)
+
+## Internals
+### Config file
+For SFT, we need to set up the config file as follows:
+```yaml
+- data:
+    dataset:
+      hf_dataset: Magpie-Align/Magpie-Pro-300K-Filtered
+      hf_dataset_split: train
+      conversation_column_name: conversations
+      train_on_completions_only: true
+      remove_cross_attention: true
+    num_loading_workers: 1
+    seed: 42
+  name: General purpose training (Single dataset)
+  start_training_step: 1
+```
+The `hf_dataset` should be a dataset from the HuggingFace Hub with the same structure as `Magpie-Align/Magpie-Pro-300K-Filtered`; that is, each conversation is a list of dictionaries, each with the keys `from` (either `gpt` or `human`) and `value`.
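+For illustration, one such conversation could look like this (a hypothetical, minimal example; real samples are longer):
+```python
+# Hypothetical example of a single row of the `conversations` column.
+# Each turn is a dict with a "from" key ("human" or "gpt") and a "value" key
+# holding the message text; the loss will only be computed on the "gpt" turns.
+conversation = [
+    {"from": "human", "value": "What is the capital of Switzerland?"},
+    {"from": "gpt", "value": "The capital of Switzerland is Bern."},
+]
+```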
+We can select a split with `hf_dataset_split` and the dataset column with `conversation_column_name`. `train_on_completions_only` & `remove_cross_attention` toggle Features 2 and 3 on/off; we will remove them for the final release.
+
+### Iterable Dataset
+For SFT training, we have developed a new dataset, [`ChatDataset`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_dataset.py#L17), responsible for producing data batches during training. Unlike `Nanosets`, this new `ChatDataset` is an [`IterableDataset`](https://pytorch.org/docs/stable/data.html#iterable-style-datasets). The advantage of this type of dataset is that it does not require preprocessing the data before training, as it does so on the fly, saving us the preprocessing step and the space occupied by the preprocessed data. The downside is that it is not trivial to recover the state of the DataLoader when restarting training. For this, we are developing a solution based on `torchdata`'s [`StatefulDataLoader`](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) that we will incorporate soon.
+
+For now, we allow splitting the dataset across the different data parallel ranks and plan to support interleaved datasets.
+
+### ChatTokenizer
+To apply the chat template, tokenize the conversations, and store the role of each token, we have developed the [`ChatTokenizer`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_tokenizer.py#L6). Based on the one included in [`meta/llama3`](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), [this tokenizer returns](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_dataset.py#L92) the `tokens` of the conversation and a list of bools, `is_completions`, indicating whether each token belongs to the model's responses, which is necessary for Feature 2.
+
+For now, we only support the Llama3 tokenizer along with the official chat template of this model.
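+
+As a rough usage sketch (not part of the training pipeline; the tokenizer name below is a placeholder for any Llama3 tokenizer directory or Hub model id), `ChatTokenizer` can be exercised on a single toy conversation like this:
+```python
+from nanotron.data.chat_tokenizer import ChatTokenizer
+
+# Placeholder: point this to a Llama3 tokenizer (local directory or Hub model id).
+chat_tokenizer = ChatTokenizer("meta-llama/Meta-Llama-3-8B-Instruct")
+
+conversation = [
+    {"from": "human", "value": "What is the capital of Switzerland?"},
+    {"from": "gpt", "value": "The capital of Switzerland is Bern."},
+]
+
+tokens, is_completions = chat_tokenizer(conversation)
+# `tokens` holds the token ids of the conversation with the chat template applied;
+# `is_completions` holds one bool per token, True only for the assistant ("gpt")
+# answers (and their closing <|eot_id|>), i.e. the positions the loss is computed on.
+assert len(tokens) == len(is_completions)
+```
+`ChatDataset` consumes exactly these two lists when packing conversations into a single sample and building the `position_ids` needed for Feature 3.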
+ +### Recover DataLoader +Pending development diff --git a/docs/sft_feature1.png b/docs/sft_feature1.png new file mode 100644 index 00000000..162322f0 Binary files /dev/null and b/docs/sft_feature1.png differ diff --git a/docs/sft_feature2.png b/docs/sft_feature2.png new file mode 100644 index 00000000..2dbd9803 Binary files /dev/null and b/docs/sft_feature2.png differ diff --git a/docs/sft_feature3.png b/docs/sft_feature3.png new file mode 100644 index 00000000..029a4639 Binary files /dev/null and b/docs/sft_feature3.png differ diff --git a/examples/config_llama8b_sft.yaml b/examples/config_llama8b_sft.yaml new file mode 100644 index 00000000..cf6e2db7 --- /dev/null +++ b/examples/config_llama8b_sft.yaml @@ -0,0 +1,98 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/ + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + hf_dataset: Magpie-Align/Magpie-Pro-300K-Filtered + hf_dataset_split: train + conversation_column_name: conversations + train_on_completions_only: true + remove_cross_attention: true + num_loading_workers: 1 + seed: 42 + name: General purpose training (Single dataset) + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: SFT-Todi + run: Llama3-8B + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: /store/swissai/a06/models/nanotron_checkpoints/Meta-Llama-3.1-8B-Instruct + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 128000 + eos_token_id: 128001 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 14336 + is_llama_config: true + max_position_embeddings: 131072 + num_hidden_layers: 32 + num_attention_heads: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rope_interleaved: false + rope_theta: 500000.0 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: false + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 2.0e-5 + lr_decay_starting_step: null + lr_decay_steps: 98 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 4 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 4 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: /store/swissai/a06/models/nanotron_checkpoints/Meta-Llama-3.1-8B-Instruct + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 4 + sequence_length: 4096 + train_steps: 250 + val_check_interval: -1 diff --git a/pyproject.toml b/pyproject.toml index 6a0cfb83..4810a60a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "safetensors", "dacite", "tqdm", + "wandb", ] [tool.setuptools.packages.find] diff --git a/run_train.py b/run_train.py index 021d955d..ae89365c 100644 --- a/run_train.py +++ b/run_train.py @@ -12,8 +12,9 @@ import numpy as np from nanotron import logging -from nanotron.config import DataArgs, 
DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs -from nanotron.data.dataloader_builder import build_nanoset_dataloader +from nanotron.config import ChatDatasetsArgs, DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs +from nanotron.data.chat_dataset import ChatDataset +from nanotron.data.dataloader_builder import build_chat_dataloader, build_nanoset_dataloader from nanotron.dataloader import ( clm_process, dummy_infinite_data_generator, @@ -171,6 +172,30 @@ def get_dataloader_from_data_stage( dataloader_drop_last=True, ) + return train_dataloader + # Case 4: Chat Datasets + elif isinstance(data.dataset, ChatDatasetsArgs): + with main_rank_first(trainer.parallel_context.world_pg): + train_dataset = ChatDataset( + dataset_path=data.dataset.hf_dataset, + tokenizer_name_or_path=trainer.config.tokenizer.tokenizer_name_or_path, + sequence_length=trainer.sequence_length, + train_on_completions_only=data.dataset.train_on_completions_only, + remove_cross_attention=data.dataset.remove_cross_attention, + split=data.dataset.hf_dataset_split, + conversation_column_name=data.dataset.conversation_column_name, + dp_rank=trainer.parallel_context.dp_pg.rank(), + dp_ranks_size=trainer.parallel_context.dp_pg.size(), + ) + + # Prepare dataloader + train_dataloader = build_chat_dataloader( + dataset=train_dataset, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + ) + return train_dataloader else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}") diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 05b49955..96337e9a 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -107,11 +107,27 @@ def __post_init__(self): self.dataset_weights = list(tmp_dataset_folder.values()) +@dataclass +class ChatDatasetsArgs: + hf_dataset: str + hf_dataset_split: str + conversation_column_name: str + # Debug + train_on_completions_only: bool = True + remove_cross_attention: bool = True + + def __post_init__(self): + if self.hf_dataset_split is None: + self.hf_dataset_split = "train" + if self.conversation_column_name is None: + self.conversation_column_name = "conversations" + + @dataclass class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] + dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/data/chat_dataset.py b/src/nanotron/data/chat_dataset.py new file mode 100644 index 00000000..79ec9be5 --- /dev/null +++ b/src/nanotron/data/chat_dataset.py @@ -0,0 +1,134 @@ +from typing import List + +import numpy as np +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node +from nanotron.data.chat_tokenizer import ChatTokenizer +from nanotron.data.collator import ( + build_labels, + build_labels_completions_only, + build_position_ids, + build_position_ids_dummy, +) +from torch.utils.data import IterableDataset +from transformers import AutoTokenizer + + +class ChatDataset(IterableDataset): + """ + Chat Dataset for training models with: + 1. Padding-Free Packing + 2. No cross-contamination between packed samples + 3. Train on completitions only + + Args: + dataset_path (str): Path to the dataset in the file system. If provided, data will be loaded from this path instead of downloaded. 
+ tokenizer_name_or_path (str): Path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub. + seq_len (int): max sequence length + train_on_completions_only (bool): Whether to just train on completitions or not. To be deleted + remove_cross_attention (bool): Whether to just attend to the tokens from the same sample or to all (Vanilla mechanism). To be deleted + split (str): Split of the dataset to train on + conversation_column_name (str): Column name of the dataset containing the conversations + dp_rank (int): rank of the current data parallel process + dp_ranks_size (int): number of data parallel processes participating in training + """ + + def __init__( + self, + dataset_path: str, + tokenizer_name_or_path, + sequence_length: int, + conversation_column_name: str, + train_on_completions_only: bool = True, + remove_cross_attention: bool = True, + split: str = "train", + dp_rank: int = 0, + dp_ranks_size: int = 1, + skip_num_samples: int = None, # TODO(tj.solergibert) Delete, check later comment + seed: int = 1234, + ) -> None: + + # WARN(tj.solergibert) Currently we DON'T support recovering training from a interruption. Check the following TODOs + # TODO(tj.solergibert) Support checkpointing for resuming training. We have to store the number of consumed samples from the dataset (Which is different from the number of steps) and the BUFFERS. + # skip_num_samples will fail, as it's computed with the number of steps and as we are packing sequences we might have consumed MORE samples from the dataset + # TODO(tj.solergibert) Support interleaving datasets + + self.dataset_path = dataset_path + self.chat_tokenizer = ChatTokenizer(tokenizer_name_or_path) + self.sequence_length = sequence_length + self.conversation_column_name = conversation_column_name + self.skip_num_samples = skip_num_samples + self.seed = seed + + # Load, split and shuffle dataset + self.dataset = load_dataset(dataset_path, split=split, streaming=True) + self.dataset = split_dataset_by_node(self.dataset, dp_rank, dp_ranks_size) + self.dataset = self.dataset.shuffle(seed=seed, buffer_size=10_000) + + # TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting + if train_on_completions_only: + self.create_labels = build_labels_completions_only + else: + self.create_labels = build_labels + + # TODO Delete (debug), just 4 switching the remove cross-attention setting + if remove_cross_attention: + self.create_position_ids = build_position_ids + else: + self.create_position_ids = build_position_ids_dummy + + # TODO(tj.solergibert) Delete (debug) + self.debug_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) # TODO delete debug + self.debug_tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['from'] + '<|end_header_id|>\n\n'+ message['value'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>' }}{% endif %}" + + def __iter__(self): + max_buffer_token_len = 1 + self.sequence_length + buffer_tokens: List[int] = [] + buffer_is_completition: List[int] = [] + buffer_lengths: List[int] = [] + + while True: + for sample in iter(self.dataset): + tokens, is_completition = self.chat_tokenizer(sample[self.conversation_column_name]) + + # 
TODO(tj.solergibert) Delete (debug). Check if HF apply_chat_template produces the same result as ChatTokenizer + # The [:-1] of tokens is because apply chat template doesn't adds eos (NOT eot) token + assert ( + self.debug_tokenizer.apply_chat_template(sample["conversations"]) == tokens[:-1] + ), f'{self.debug_tokenizer.apply_chat_template(sample["conversations"])}\n\n{tokens[:-1]}' + + buffer_tokens.extend(tokens) + buffer_is_completition.extend(is_completition) + buffer_lengths.append(len(tokens)) + + if len(buffer_tokens) > max_buffer_token_len: # Can't pack more samples, yield + # Pop last sample from buffers + sample_tokens = buffer_tokens[: -len(tokens)] + sample_completitions = buffer_is_completition[: -len(tokens)] + sample_lengths = buffer_lengths[:-1] + + # TODO(tj.solergibert) Delete (debug) + assert len(sample_tokens) == len(sample_completitions) == sum(sample_lengths) + + # Reset tokens buffers + buffer_tokens = tokens.copy() + buffer_is_completition = is_completition.copy() + buffer_lengths = [len(tokens)] + + # TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting + sample_completitions = self.create_labels(sample_tokens, sample_completitions) + + # TODO(tj.solergibert) Delete (debug), just 4 switching the remove cross-attention setting + position_ids = self.create_position_ids(sample_lengths, self.sequence_length) + + # TODO(tj.solergibert) Delete (debug) + # assert len(sample_tokens) <= max_buffer_token_len + + yield { + "input_ids": np.array(sample_tokens, dtype=np.int32), + "is_completitions": np.array(sample_completitions, dtype=np.bool_), + "position_ids": position_ids, + } + + # TODO(tj.solergibert) Change for log_rank (log_rank is problematic with JupyterNB) + print("Consumed all samples, dataset is being re-looped.") diff --git a/src/nanotron/data/chat_tokenizer.py b/src/nanotron/data/chat_tokenizer.py new file mode 100644 index 00000000..c3252925 --- /dev/null +++ b/src/nanotron/data/chat_tokenizer.py @@ -0,0 +1,83 @@ +from typing import List, Tuple + +from transformers import AutoTokenizer + + +class ChatTokenizer: + """ + The ChatTokenizer encodes a conversation applying the Llama3 Chat Template and returns the role (Either User or Assistant) of each token + + Args: + tokenizer_name_or_path (str): A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub. + """ + + def __init__(self, tokenizer_name_or_path: str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + + # Add pad token if necessary + if self.tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({"pad_token": "<|eot_id|>"}) + + def __call__(self, conversation: List[dict]) -> Tuple[List[int], List[bool]]: + """ + Applies the Llama3 chat template, encodes the conversation and returns the tokens along with a bool value for each token whether if the token belongs to the answer of the assistant or not to be able to just train on the assistant answers + Args: + conversation (List[dict]): List of dicts where each dict contains the "from" key to specify the emisor del mensaje and the "value" key with the message. + Same format as SlimOrca dataset with possible from values: "System", "human" and "gpt" + Example: + conversation: [ { "from": "system", "value": "You are an AI assistant that follows instruction extremely well. 
Help as much as you can."}, + { "from": "human", "value": "Answer the following question: - number is 54 - debutteam is pittsburgh steelers - draftpick is 166 - birth date is 24 may 1982 - weight is 243 - nfl is wal475737 - debutyear is 2005 - finalteam is new york sentinels - statlabel is tackles sacks interceptions - heightin is 3 - statvalue is 9 0.0 1 - heightft is 6 - college is temple - birth place is pottstown , pennsylvania - draftyear is 2005 - position is linebacker - draftround is 5 - finalyear is 2009 Given the details above, guess who could this information be about.\nAnswer:"}, + { "from": "gpt", "value": "The information provided seems to refer to Rian Wallace, a former NFL player."} ] + + After applying chat template: + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + You are an AI assistant that follows instruction extremely well. Help as much as you can.<|eot_id|><|start_header_id|>human<|end_header_id|> + + Answer the following question: - number is 54 - debutteam is pittsburgh steelers - draftpick is 166 - birth date is 24 may 1982 - weight is 243 - nfl is wal475737 - debutyear is 2005 - finalteam is new york sentinels - statlabel is tackles sacks interceptions - heightin is 3 - statvalue is 9 0.0 1 - heightft is 6 - college is temple - birth place is pottstown , pennsylvania - draftyear is 2005 - position is linebacker - draftround is 5 - finalyear is 2009 Given the details above, guess who could this information be about. + Answer:<|eot_id|><|start_header_id|>gpt<|end_header_id|> + + The information provided seems to refer to Rian Wallace, a former NFL player.<|eot_id|> + returns: + tokens (List[int]): A list of tokens e.g. [128000, 128006, 9125, 128007, 271, 2675, 527, ..., 12873, 2851, 13, 128009, 128001] + is_completitions (List[bool]): A list of bools whether the tokens belong to the assistant answer or not e.g. [False, False, False, ..., False, True, True, True, True] + """ + tokens = [] + # Append <|begin_of_text|> + tokens.append(self.tokenizer.bos_token_id) + is_completitions = [False] * len(tokens) + + for message in conversation: + message_tokens, message_completitions = self.encode_message(message) + tokens.extend(message_tokens) + is_completitions.extend(message_completitions) + + # Append <|end_of_text|> token + tokens.extend(self.tokenizer.encode("<|end_of_text|>", add_special_tokens=False)) + is_completitions.append( + False + ) # NOTE(tj.solergibert) No need to predict <|end_of_text|> token from <|eot_id|> token + + return tokens, is_completitions + + def encode_message(self, message: dict) -> Tuple[List[int], List[int]]: + # NOTE(tj.solergibert) The "from", "value", "gpt" keys are from SlimOrca Dataset. Llama3 HF Pretrained tokenizer uses another ones. 
We should stick to a + # single format and document it properly rather than supporting multiple formats, as each DATASET will need a different + # ChatTokenizer and the idea is that all Datasets share the same ChatTokenizer + + # Encode header + tokens = self.tokenizer.encode( + f"<|start_header_id|>{message['from']}<|end_header_id|>\n\n", add_special_tokens=False + ) + is_completitions = [False] * len(tokens) + + # Encode message + tokens.extend(self.tokenizer.encode(message["value"].strip(), add_special_tokens=False)) + + # Append <|eot_id|> token + tokens.extend(self.tokenizer.encode("<|eot_id|>", add_special_tokens=False)) + + # True if token belongs to assistant answer, False otherwise + is_completitions.extend([True if message["from"] == "gpt" else False] * (len(tokens) - len(is_completitions))) + + return tokens, is_completitions diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py index 199527e1..ea68b8b8 100644 --- a/src/nanotron/data/collator.py +++ b/src/nanotron/data/collator.py @@ -1,4 +1,4 @@ -import dataclasses +from dataclasses import dataclass from typing import Dict, List, Union import numpy as np @@ -8,7 +8,7 @@ from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer -@dataclasses.dataclass +@dataclass class NanosetDataCollatorForCLM: """ Data collator used for causal language modeling with Nanosets dataset. @@ -78,3 +78,79 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni ) return result + + +# TODO(tj.solergibert) After "Beta", delete all the functs except `build_position_ids` and move `build_position_ids` to chat_dataset.py +def build_position_ids(lengths, sequence_length) -> np.array: + position_ids = [list(range(length)) for length in lengths] # Create position ids list + return np.array([x for xs in position_ids for x in xs], dtype=np.int32) # Flatten list of position ids + + +# TODO(tj.solergibert) Delete (debug), just 4 switching the remove cross-attention setting +def build_position_ids_dummy(lengths, sequence_length) -> np.array: + return np.array(list(range(sum(lengths))), dtype=np.int32) # TODO numpy arange + + +# TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting. +def build_labels_completions_only(input_ids, is_completitions): + return is_completitions + + +# TODO(tj.solergibert) Delete (debug), just 4 switching the training only on completitions setting +def build_labels(input_ids, is_completitions): + return [True for _ in range(len(is_completitions))] + + +@dataclass +class DataCollatorForSFT: + """ + Data collator used with Chat Dataset. + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. 
+ + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "position_ids": TensorPointer(group_rank=self.input_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + is_completitions = np.vstack([examples[i]["is_completitions"] for i in range(len(examples))]) # (b, s) + position_ids = np.vstack([examples[i]["position_ids"] for i in range(len(examples))]) # (b, s) + + result: Dict[str, Union[np.ndarray, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["position_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + # Process inputs + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["position_ids"] = position_ids[:, :-1] + + # Process labels: shift them to the left. + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = is_completitions[:, 1:] + + # Cast np.array to torch.Tensor + result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()} + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 9d3285f6..2136cfcc 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,6 +1,6 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import NanosetDataCollatorForCLM +from nanotron.data.collator import DataCollatorForSFT, NanosetDataCollatorForCLM from nanotron.dataloader import ( EmptyInfiniteDataset, get_dataloader_worker_init, @@ -62,3 +62,34 @@ def build_nanoset_dataloader( pin_memory=dataloader_pin_memory, worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank), ) + + +def build_chat_dataloader( + dataset, + parallel_context: ParallelContext, + input_pp_rank: int, + output_pp_rank: int, + dataloader_pin_memory: bool = True, +) -> DataLoader: + + # Case of ranks not requiring data. We give them a dummy dataset, then the collator will do his job + if dist.get_rank(parallel_context.pp_pg) not in [input_pp_rank, output_pp_rank]: + dataset_length = 1_000_000 # len(dataset) TODO find a more elegant way to specify this dummy dataset + dataset = EmptyInfiniteDataset(length=dataset_length) + + data_collator = DataCollatorForSFT( + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + + dp_rank = parallel_context.dp_pg.rank() + + return DataLoader( + dataset, + batch_size=1, + collate_fn=data_collator, + num_workers=0, + pin_memory=dataloader_pin_memory, + worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank), + ) diff --git a/src/nanotron/models/llama_sft.py b/src/nanotron/models/llama_sft.py new file mode 100644 index 00000000..35df7cab --- /dev/null +++ b/src/nanotron/models/llama_sft.py @@ -0,0 +1,821 @@ +# coding=utf-8 +# Copyright 2018 HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaMa model.""" + +from typing import Dict, Optional, Union + +import torch +from flash_attn import flash_attn_varlen_func +from torch import nn + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import Config, LlamaConfig, ParallelismArgs +from nanotron.config.models_config import RandomInit, SpectralMupInit +from nanotron.generation.generate_store import AttachableStore +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.nn.activations import ACT2FN +from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer +from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.random import RandomStates +from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator + +logger = logging.get_logger(__name__) + +#################################################################################### +############################## SFT Auxiliary functions ############################## +#################################################################################### +## Copied RoPE functions from HF Transformers. Nanotron ships with FlashAttention ## +## RoPEs written in triton which are considerbly faster BUT currently they don't ## +## support the poisiton ids necessary for the cross attention feature. The cos & ## +## sin are created in the embedding layer and propagated through the pipeline so ## +## we don't have a RoPE layer in each and every decoder layer. Then in each and ## +## every decoder layer we apply the cos & sin to Q & K with `apply_rotary_pos_emb`## +#################################################################################### + +# NOTE(tj.solergibert) Copied from https://github.com/huggingface/transformers/blob/81233c069c166af033794134bd8888783ac49ebe/src/transformers/modeling_rope_utils.py#L29 +def _compute_default_rope_parameters( + config: LlamaConfig, +) -> torch.Tensor: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config (LlamaConfig): + The model configuration. 
+ Returns: + inv_freq (torch.Tensor) + Contains the inverse frequencies for the RoPE embeddings + """ + with torch.autocast(device_type="cuda", enabled=False): + base = config.rope_theta # NOTE(tj.solergibert) 500000.0 + dim = int(config.hidden_size // config.num_attention_heads) # NOTE(tj.solergibert) 128 + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)).cuda() + return inv_freq + + +# NOTE(tj.solergibert) Copied from https://github.com/huggingface/transformers/blob/5f841c74b62754f186a8c06a684d491524b7bc03/src/transformers/models/llama/modeling_llama.py#L81 +# NOTE(tj.solergibert) FlashAttention RoPEs are faster (triton), but currently they don't support position_ids +# NOTE(tj.solergibert) This function is just called once per batch to compute the position_embeddings, the expensive operation +# is def apply_rotary_pos_emb +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + config: LlamaConfig, + ): + super().__init__() + self.config = config + + self.inv_freq = _compute_default_rope_parameters(self.config) + # self.register_buffer("inv_freq", inv_freq, persistent=False) # NOTE(tj.solergibert) register_buffer casts to bf16! + # self.original_inv_freq = inv_freq + + @torch.no_grad() + def forward(self, x, position_ids): + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# NOTE(tj.solergibert) FlashAttention RoPEs are faster (triton), but currently they don't support position_ids +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (torch.Tensor): The query tensor. + k (torch.Tensor): The key tensor. + cos (torch.Tensor): The cosine part of the rotary embedding. + sin (torch.Tensor): The sine part of the rotary embedding. + unsqueeze_dim (int, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + tuple (torch.Tensor) comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def prepare_varlen_args(position_ids): + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + cu_seqlens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_seqlen_in_batch = position_ids.max() + 1 + + return cu_seqlens, max_seqlen_in_batch + + +#################################################################################### + + +class GLUActivation(nn.Module): + def __init__(self, act_fn_name: str): + super().__init__() + self.act = ACT2FN[act_fn_name] + + def forward(self, merged_states: torch.Tensor): + gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1) + return self.act(gate_states) * up_states + + +class MLP(nn.Module): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + # TODO @thomasw21: refactor so that we store that default in a single place. + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + gate_up_contiguous_chunks = ( + config.intermediate_size, # shape of gate_linear + config.intermediate_size, # shape of up_linear + ) + self.gate_up_proj = TensorParallelColumnLinear( + config.hidden_size, + 2 * config.intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + contiguous_chunks=gate_up_contiguous_chunks, + ) + self.down_proj = TensorParallelRowLinear( + config.intermediate_size, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ) + # TODO @nouamane: why can't we torch.jit.script GLUActivation? + self.split_silu_mul = GLUActivation(config.hidden_act) + + def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states) + hidden_states = self.down_proj(self.split_silu_mul(merged_states)) + return hidden_states + + +class CausalSelfAttention(nn.Module, AttachableStore): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + ): + + super().__init__() + # Tensor parallel considerations: We split tensors along head dimension + assert ( + config.num_attention_heads % tp_pg.size() == 0 + ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})." + try: + assert ( + config.num_key_value_heads % tp_pg.size() == 0 + ), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})." 
+ except AttributeError: + log_rank( + "WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads", + logger=logger, + level=logging.WARNING, + rank=0, + ) + # If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads + config.num_key_value_heads = config.num_attention_heads + assert ( + config.num_attention_heads % config.num_key_value_heads == 0 + ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})." + self.n_local_q_heads = config.num_attention_heads // tp_pg.size() + self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size() + self.n_repeats = config.num_attention_heads // config.num_key_value_heads + self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not + self.d_qk = config.hidden_size // config.num_attention_heads + self.d_v = config.hidden_size // config.num_attention_heads + self.d_model = config.hidden_size + + # TODO @thomasw21: refactor so that we store that default in a single place. + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + # build the slice config for self.qkv for save/load + # shard are done within the contiguous chunk + qkv_contiguous_chunks = ( + config.num_attention_heads * self.d_qk, # shape of q + config.num_key_value_heads * self.d_qk, # shape of k + config.num_key_value_heads * self.d_qk, # shape of v + ) + self.qkv_proj = TensorParallelColumnLinear( + self.d_model, + config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + contiguous_chunks=qkv_contiguous_chunks, + ) + + self.o_proj = TensorParallelRowLinear( + config.num_attention_heads * self.d_qk, + self.d_model, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + + def forward( + self, + hidden_states, # [seq_length, batch_size, hidden_size] + position_ids, # [batch_size, seq_length] + cos, # [batch_size, seq_length, hidden_size//num_attention_heads] + sin, # [batch_size, seq_length, hidden_size//num_attention_heads] + ): + qkv_states = self.qkv_proj( + hidden_states + ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] + q_length, batch_size, _ = qkv_states.shape + + if self.is_gqa: + query_states, key_states, value_states = torch.split( + qkv_states, + [ + self.n_local_q_heads * self.d_qk, + self.n_local_kv_heads * self.d_qk, + self.n_local_kv_heads * self.d_qk, + ], + dim=-1, + ) + + query_states = ( + query_states.transpose(0, 1) + .contiguous() + .view( + batch_size, q_length, self.n_local_q_heads, self.d_qk + ) # TODO(tj.solergibert) q_length to -1 BUT q_lenght is already well computed + ) + key_states = ( + key_states.transpose(0, 1) + .contiguous() + .view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) # TODO(tj.solergibert) q_length to -1 + ) + value_states = ( + value_states.transpose(0, 1) + .contiguous() + .view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) # TODO(tj.solergibert) q_length to -1 + ) + else: + query_states, key_states, value_states = ( + qkv_states.view( + q_length, batch_size, 3, self.n_local_q_heads, self.d_qk + ) # TODO(tj.solergibert) q_length to -1 + .permute(2, 
1, 0, 3, 4) + .contiguous() + ) # [3, batch_size, seq_length, n_local_q_heads, d_qk] + + # TODO(tj.solergibert) Apply RoPE embeddings WITHOUT too many transpose... + query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2) + # Apply RoPE + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2) + + # Prepare varlen args + cu_seqlens, max_seqlen_in_batch = prepare_varlen_args(position_ids) + + query_states = query_states.view(-1, query_states.size(-2), query_states.size(-1)) + key_states = key_states.view(-1, key_states.size(-2), key_states.size(-1)) + value_states = value_states.view(-1, value_states.size(-2), value_states.size(-1)) + + attention_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_in_batch, + max_seqlen_k=max_seqlen_in_batch, + causal=True, + ) # NOTE(tj.solergibert) Returns out: (total, nheads, headdim). + + attention_output = ( + attention_output.contiguous() + .view(batch_size, q_length, self.n_local_q_heads * self.d_v) + .transpose(0, 1) # TODO(tj.solergibert) View is necessary, but contiguous? + ) + output = self.o_proj(attention_output) + + return output + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__() + self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = CausalSelfAttention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx, + ) + + self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + cos: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + sin: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.attn(hidden_states=hidden_states, position_ids=position_ids, cos=cos, sin=sin) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states) + hidden_states = hidden_states + residual + + return { + "hidden_states": hidden_states, + "position_ids": position_ids, + "cos": cos, + "sin": sin, + } + + +class Embedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.token_embedding = TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + # NOTE(tj.solergibert) SFT + self.position_embedding = LlamaRotaryEmbedding(config=config) + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # [batch_size, seq_length] + # Format input in `[seq_length, batch_size]` to support 
high TP with low batch_size + input_ids = input_ids.transpose(0, 1) + input_embeds = self.token_embedding(input_ids) + + # NOTE(tj.solergibert) We create the cos & sin and propagate them through the pipeline so we + # don't have to create the LlamaRotaryEmbedding layer in each and every decoder layer + # We will still send the position ids for the varlen + cos, sin = self.position_embedding( + input_embeds, position_ids + ) # TODO(tj.solergibert) We just need from inputs_ids the device type + + return {"input_embeds": input_embeds, "position_ids": position_ids, "cos": cos, "sin": sin} + + +class LlamaModel(nn.Module): + """Build pipeline graph""" + + def __init__( + self, + config: LlamaConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.config = config + self.parallel_config = parallel_config + self.parallel_context = parallel_context + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.token_position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids", "position_ids"}, + module_output_keys={"input_embeds", "position_ids", "cos", "sin"}, + ) + + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=LlamaDecoderLayer, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "position_ids", "cos", "sin"}, + module_output_keys={"hidden_states", "position_ids", "cos", "sin"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonRMSNorm, + module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) # TODO + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO @thomasw21: refactor so that we store that default in a single place. 
+ "mode": self.tp_mode, + "async_communication": tp_linear_async_communication, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + return self.forward_with_hidden_states(input_ids=input_ids, position_ids=position_ids)[0] + + def forward_with_hidden_states( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + # all tensors are optional as most ranks don't need anything from the dataloader. + + hidden_encoder_states = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) + + # NOTE(tj.solergibert) Rename input_embeds --> hidden_states + hidden_encoder_states["hidden_states"] = hidden_encoder_states.pop("input_embeds") + + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits, hidden_states + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + model_config = self.config + d_ff = model_config.intermediate_size + d_qkv = model_config.hidden_size // model_config.num_attention_heads + block_compute_costs = { + # CausalSelfAttention (qkv proj + attn out) + MLP + LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + + 3 * d_ff * model_config.hidden_size, + # This is the last lm_head + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + } + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + world_size = self.parallel_context.world_pg.size() + try: + num_key_values_heads = self.config.num_key_value_heads + except AttributeError: + num_key_values_heads = self.config.num_attention_heads + + model_flops, hardware_flops = get_flops( + num_layers=self.config.num_hidden_layers, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_key_value_heads=num_key_values_heads, + vocab_size=self.config.vocab_size, + ffn_hidden_size=self.config.intermediate_size, + seq_len=sequence_length, + batch_size=global_batch_size, + ) + + model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + return model_flops_per_s, hardware_flops_per_s + + +@torch.jit.script +def masked_mean(loss, label_mask, dtype): + # type: (Tensor, Tensor, torch.dtype) -> Tensor + return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + + +class Loss(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup): + super().__init__() + self.tp_pg = tp_pg + + def forward( + self, + sharded_logits: torch.Tensor, # [seq_length, batch_size, logits] + label_ids: torch.Tensor, # [batch_size, 
seq_length]
+        label_mask: torch.Tensor,  # [batch_size, seq_length]
+    ) -> Dict[str, torch.Tensor]:
+        # Megatron by default casts everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
+        # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
+
+        loss = sharded_cross_entropy(
+            sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
+        ).transpose(0, 1)
+        # TODO @thomasw21: It's unclear what kind of normalization we want to do.
+        loss = masked_mean(loss, label_mask, dtype=torch.float)
+        # I think indexing causes a sync we don't actually want
+        # loss = loss[label_mask].sum()
+        return {"loss": loss}
+
+
+class LlamaForSFT(NanotronModel):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        parallel_context: ParallelContext,
+        parallel_config: Optional[ParallelismArgs],
+        random_states: Optional[RandomStates] = None,
+    ):
+        super().__init__()
+        self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
+        self.loss = PipelineBlock(
+            p2p=self.model.p2p,
+            module_builder=Loss,
+            module_kwargs={"tp_pg": parallel_context.tp_pg},
+            module_input_keys={
+                "sharded_logits",
+                "label_ids",
+                "label_mask",
+            },
+            module_output_keys={"loss"},
+        )
+        self.parallel_context = parallel_context
+        self.config = config
+        self.parallel_config = parallel_config
+
+    def forward(
+        self,
+        input_ids: Union[torch.Tensor, TensorPointer],
+        position_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
+        label_ids: Union[torch.Tensor, TensorPointer],
+        label_mask: Union[torch.Tensor, TensorPointer],
+    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
+
+        sharded_logits = self.model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+        )
+
+        loss = self.loss(
+            sharded_logits=sharded_logits,
+            label_ids=label_ids,
+            label_mask=label_mask,
+        )["loss"]
+
+        return {"loss": loss}
+
+    @torch.no_grad()
+    def init_model_randomly(self, config: Config):
+        """Initialize model parameters randomly.
+        Note:
+            Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
+        """
+        init_method = config.model.init_method
+        if isinstance(init_method, RandomInit):
+            parametrizator_cls = StandardParametrizator
+        elif isinstance(init_method, SpectralMupInit):
+            parametrizator_cls = SpectralMupParametrizator
+        else:
+            raise ValueError(f"Unknown init method {init_method}")
+
+        parametrizator = parametrizator_cls(config=config.model)
+
+        log_rank(
+            f"Parametrizing model parameters using {parametrizator.__class__.__name__}",
+            logger=logger,
+            level=logging.INFO,
+            rank=0,
+        )
+
+        model = self
+        initialized_parameters = set()
+        # Handle tensor parallelism
+        module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
+        # Fix the root_model
+        module_id_to_prefix[id(model)] = ""
+
+        for param_name, param in model.named_parameters():
+            assert isinstance(param, NanotronParameter)
+
+            module_name, param_name = param_name.rsplit(".", 1)
+
+            if param.is_tied:
+                tied_info = param.get_tied_info()
+                full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
+                    module_id_to_prefix=module_id_to_prefix
+                )
+            else:
+                full_param_name = f"{module_name}.{param_name}"
+
+            if full_param_name in initialized_parameters:
+                # Already initialized
+                continue
+
+            module = model.get_submodule(module_name)
+            parametrizator.parametrize(param_name, module)
+
+            assert full_param_name not in initialized_parameters
+            initialized_parameters.add(full_param_name)
+
+        assert initialized_parameters == {
+            param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
+            if param.is_tied
+            else name
+            for name, param in model.named_parameters()
+        }, f"Somehow the initialized set of parameters doesn't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
+
+    def get_embeddings_lm_head_tied_names(self):
+        """Get the names of the tied embeddings and lm_head weights"""
+        if self.config.tie_word_embeddings is True:
+            return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
+        else:
+            return []
+
+    def get_block_compute_costs(self):
+        """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
+        return self.model.get_block_compute_costs()
+
+    def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
+        """Get flops per second for a given model"""
+        return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
+
+
+def get_flops(
+    num_layers,
+    hidden_size,
+    num_heads,
+    num_key_value_heads,
+    vocab_size,
+    seq_len,
+    ffn_hidden_size,
+    batch_size=1,
+):
+    """Counts flops in a decoder-only model
+    Args:
+        num_layers: number of decoder layers
+        hidden_size: hidden size of the model
+        num_heads: number of heads in the model
+        num_key_value_heads: number of key/value heads in the model
+        ffn_hidden_size: hidden size of the FFN
+        vocab_size: size of the vocabulary
+        seq_len: sequence length of the decoder
+        batch_size: batch size
+    Returns:
+        model_flops: flops in the model (should be independent of the hardware and model implementation)
+        hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
+    """
+    if num_key_value_heads is None:
+        num_key_value_heads = num_heads
+    hidden_size_per_head = hidden_size // num_heads
+    # In the following we mark the reduced dimension with parentheses
+    # decoder
+    # self attention
+    ## qkv projection
+    decoder_qkv_proj_flops_fwd = (
+        2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head
+        + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head
+    )
+    ## qk logits
+    decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len
+    ## v logits
+    decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head
+    ## attn out
+    decoder_attn_out_flops_fwd = (
+        2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size
+    )
+    # FF
+    ## 1st layer
+    decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
+    ## 2nd layer
+    decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
+
+    decoder_flops_fwd = (
+        decoder_qkv_proj_flops_fwd
+        + decoder_qk_logits_flops_fwd
+        + decoder_v_logits_flops_fwd
+        + decoder_attn_out_flops_fwd
+        + decoder_ffn_1_flops_fwd
+        + decoder_ffn_2_flops_fwd
+    )
+
+    # lm head
+    lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
+
+    # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to
+    # both input and weight tensors
+    model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd)  # 1 for fwd + 2 for bwd
+
+    hardware_flops = model_flops  # TODO: This is a placeholder for now
+
+    return model_flops, hardware_flops
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index b6752f38..3000ae22 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -26,6 +26,7 @@
 from nanotron import distributed as dist
 from nanotron import logging
 from nanotron.config import (
+    ChatDatasetsArgs,
     Config,
     DatasetStageArgs,
     ExistingCheckpointInit,
@@ -56,7 +57,8 @@
 )
 from nanotron.models import NanotronModel, build_model
 from nanotron.models.base import check_model_has_grad
-from nanotron.models.llama import LlamaForTraining, RotaryEmbedding
+from nanotron.models.llama import LlamaForTraining
+from nanotron.models.llama_sft import LlamaForSFT
 from nanotron.models.starcoder2 import Starcoder2ForTraining
 from nanotron.optim.clip_grads import clip_grad_norm
 from nanotron.parallel import ParallelContext
@@ -102,6 +104,7 @@
 CONFIG_TO_MODEL_CLASS = {
     "LlamaConfig": LlamaForTraining,
+    "LlamaConfigForSFT": LlamaForSFT,
     "Starcoder2Config": Starcoder2ForTraining,
 }
@@ -252,6 +255,20 @@ def __init__(
         # NOTE: the dataloader currently in use for the current training stage
         self.current_dataloader: Optional[DataLoader] = None
 
+        # NOTE(tj.solergibert) Flatten batch size in SFT training
+        if isinstance(self.config.data_stages[0].data.dataset, ChatDatasetsArgs) and self.micro_batch_size != 1:
+            self.sequence_length = self.micro_batch_size * self.config.tokens.sequence_length
+            self.micro_batch_size = 1
+            self.global_batch_size = (
+                self.micro_batch_size * self.n_micro_batches_per_batch * self.parallel_context.dp_pg.size()
+            )
+            log_rank(
+                f"Flattening Batch dimension for SFT training. global_batch_size: {self.global_batch_size}, micro_batch_size: {self.micro_batch_size}, sequence_length: {self.sequence_length}",
+                logger=logger,
+                level=logging.INFO,
+                rank=0,
+            )
+
         self.post_init()
 
     def pre_init(self):
@@ -670,6 +687,10 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
 
     def _init_model_instance(self) -> NanotronModel:
         model_config_cls = self.model_config.__class__.__name__
+
+        if model_config_cls == "LlamaConfig" and isinstance(self.config.data_stages[0].data.dataset, ChatDatasetsArgs):
+            model_config_cls = "LlamaConfigForSFT"
+
         assert (
             model_config_cls in CONFIG_TO_MODEL_CLASS
         ), f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported"
@@ -750,11 +771,12 @@ def _init_model(
             model_builder=model_builder,
         )
 
+        # TODO(tj.solergibert) Fix this RoPE init only used with LlamaModel for generation?
         # Initialize rotary embeddings
-        for module in model.modules():
-            if not isinstance(module, RotaryEmbedding):
-                continue
-            module.init_rotary_embeddings()
+        # for module in model.modules():
+        #     if not isinstance(module, RotaryEmbedding):
+        #         continue
+        #     module.init_rotary_embeddings()
 
         # Mark some parameters as tied
         self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
diff --git a/tools/check_sft.py b/tools/check_sft.py
new file mode 100644
index 00000000..6e80b883
--- /dev/null
+++ b/tools/check_sft.py
@@ -0,0 +1,285 @@
+"""
+torchrun --nproc-per-node 1 tools/check_sft.py
+"""
+import numpy as np
+import torch
+from nanotron.config import ParallelismArgs
+from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron
+from nanotron.data.chat_dataset import ChatDataset
+from nanotron.data.dataloader_builder import build_chat_dataloader
+from nanotron.models import build_model
+from nanotron.models.llama_sft import LlamaForSFT
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine
+from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
+from nanotron.trainer import mark_tied_parameters
+from torch.nn import CrossEntropyLoss
+from torch.testing import assert_close
+from transformers import AutoModelForCausalLM, LlamaConfig
+
+dtype = torch.bfloat16
+device = torch.device("cuda")
+PATH_TO_LLAMA = "/mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct"
+
+# NOTE(tj.solergibert) This script is for testing purposes. ONLY use 1 GPU
+DP = 1
+PP = 1
+TP = 1
+
+# NOTE(tj.solergibert) How many of the K first tokens must match
+# NOTE(tj.solergibert) After running lots of tests, MOST (if not 100%) of the time the most probable token matches. Sometimes there are slight differences in the next tokens,
+# usually when the first token has a very high probability and the rest are left with < 1e-2.
+TOPK_MATCH = 1
+
+BATCHES = 15
+
+
+def hf_build_labels_completions_only(input_ids, is_completitions):
+    labels = np.where(
+        is_completitions, input_ids, -100
+    )  # Mask tokens that don't belong to the completions by the Assistant
+    return torch.tensor(np.array(labels, dtype=np.int64))
+
+
+def main():
+    hf_model = AutoModelForCausalLM.from_pretrained(
+        PATH_TO_LLAMA, torch_dtype=dtype, attn_implementation="flash_attention_2"
+    ).to(device)
+    hf_config = LlamaConfig.from_pretrained(PATH_TO_LLAMA)
+
+    parallel_config = ParallelismArgs(
+        dp=DP,
+        pp=PP,
+        tp=TP,
+        pp_engine=AllForwardAllBackwardPipelineEngine(),
+        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
+        tp_linear_async_communication=False,
+    )
+    assert (
+        parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE
+        and parallel_config.tp_linear_async_communication is False
+    )
+
+    parallel_context = ParallelContext(
+        data_parallel_size=parallel_config.dp,
+        pipeline_parallel_size=parallel_config.pp,
+        tensor_parallel_size=parallel_config.tp,
+    )
+
+    nanotron_config = LlamaConfigNanotron(
+        bos_token_id=hf_config.bos_token_id,
+        eos_token_id=hf_config.eos_token_id,
+        hidden_act=hf_config.hidden_act,
+        hidden_size=hf_config.hidden_size,
+        initializer_range=hf_config.initializer_range,
+        intermediate_size=hf_config.intermediate_size,
+        is_llama_config=True,
+        max_position_embeddings=hf_config.max_position_embeddings,
+        num_attention_heads=hf_config.num_attention_heads,
+        num_hidden_layers=hf_config.num_hidden_layers,
+        num_key_value_heads=hf_config.num_key_value_heads,
+        pad_token_id=None,
+        pretraining_tp=hf_config.pretraining_tp,
+        rms_norm_eps=hf_config.rms_norm_eps,
+        rope_scaling=hf_config.rope_scaling,
+        rope_theta=hf_config.rope_theta,
+        rope_interleaved=False,
+        tie_word_embeddings=hf_config.tie_word_embeddings,
+        use_cache=hf_config.use_cache,
+        vocab_size=hf_config.vocab_size,
+    )
+
+    nanotron_model = build_model(
+        model_builder=lambda: LlamaForSFT(
+            config=nanotron_config,
+            parallel_context=parallel_context,
+            parallel_config=parallel_config,
+            random_states=None,
+        ),
+        parallel_context=parallel_context,
+        dtype=dtype,
+        device=device,
+    )
+
+    mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context)
+
+    # Copy Llama3-8B-Instruct parameters
+    # Token embeddings
+    assert (
+        nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape
+        == hf_model.model.embed_tokens.weight.shape
+    )
+
+    with torch.no_grad():
+        nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_(
+            hf_model.model.embed_tokens.weight
+        )  # = hf_model.model.embed_tokens.weight.data
+
+    # Decoder layers
+    for i in range(nanotron_config.num_hidden_layers):
+        # Input layer norm
+        assert (
+            hf_model.model.layers[i].input_layernorm.weight.shape
+            == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape
+        )
+        with torch.no_grad():
+            nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_(
+                hf_model.model.layers[i].input_layernorm.weight
+            )  # = hf_model.model.layers[i].input_layernorm.weight
+        # Self attn
+        ## QKV
+        tmp_qkv_proj = torch.cat(
+            [
+                hf_model.model.layers[i].self_attn.q_proj.weight,
+                hf_model.model.layers[i].self_attn.k_proj.weight,
+                hf_model.model.layers[i].self_attn.v_proj.weight,
+            ],
+            dim=0,
+        )
+        assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape
+        with torch.no_grad():
+            nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(
+                tmp_qkv_proj
+            )  # = tmp_qkv_proj  # torch.nn.Parameter(tmp_qkv_proj)
+
+        ## O
+        assert (
+            hf_model.model.layers[i].self_attn.o_proj.weight.shape
+            == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape
+        )
+        with torch.no_grad():
+            nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_(
+                hf_model.model.layers[i].self_attn.o_proj.weight
+            )  # = hf_model.model.layers[i].self_attn.o_proj.weight
+        # MLP
+        ## Gate Up Proj
+        tmp_gate_up_proj = torch.cat(
+            [
+                hf_model.model.layers[i].mlp.gate_proj.weight,
+                hf_model.model.layers[i].mlp.up_proj.weight,
+            ],
+            dim=0,
+        )
+
+        assert tmp_gate_up_proj.shape == nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape
+        with torch.no_grad():
+            nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(
+                tmp_gate_up_proj
+            )  # = tmp_gate_up_proj
+        ## Down Proj
+        assert (
+            hf_model.model.layers[i].mlp.down_proj.weight.shape
+            == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape
+        )
+        with torch.no_grad():
+            nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.copy_(
+                hf_model.model.layers[i].mlp.down_proj.weight
+            )  # = hf_model.model.layers[i].mlp.down_proj.weight
+
+        # Post attn layer norm
+        assert (
+            hf_model.model.layers[i].post_attention_layernorm.weight.shape
+            == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape
+        )
+        with torch.no_grad():
+            nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_(
+                hf_model.model.layers[i].post_attention_layernorm.weight
+            )  # = hf_model.model.layers[i].post_attention_layernorm.weight
+
+    # Last layer norm
+    assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape
+    with torch.no_grad():
+        nanotron_model.model.final_layer_norm.pp_block.weight.copy_(
+            hf_model.model.norm.weight
+        )  # = hf_model.model.norm.weight
+    # LM_Head
+    assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape
+    with torch.no_grad():
+        nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight)  # = hf_model.lm_head.weight
+
+    # Create ChatDataloaders
+    train_dataset = ChatDataset(
+        dataset_path="Magpie-Align/Magpie-Pro-300K-Filtered",  # "Open-Orca/SlimOrca",
+        tokenizer_name_or_path=PATH_TO_LLAMA,
+        sequence_length=2048,
+        train_on_completions_only=False,
+        remove_cross_attention=True,
+        split="train",
+        conversation_column_name="conversations",
+        dp_rank=parallel_context.dp_pg.rank(),
+        dp_ranks_size=parallel_context.dp_pg.size(),
+    )
+
+    # Prepare dataloader
+    train_dataloader = build_chat_dataloader(
+        dataset=train_dataset,
+        parallel_context=parallel_context,
+        input_pp_rank=0,
+        output_pp_rank=0,
+    )
+
+    hf_model.eval()
+    nanotron_model.eval()
+
+    for i, batch in enumerate(train_dataloader):
+        if i == BATCHES:
+            break
+        print(f"Checking sample {i}!")
+
+        # Some DL Checks
+        assert batch["input_ids"].shape == batch["label_ids"].shape
+        assert batch["input_ids"].shape == batch["position_ids"].shape
+        assert batch["input_ids"].shape == batch["label_mask"].shape
+
+        with torch.inference_mode():
+            output_nanotron = nanotron_model.model(
+                input_ids=batch["input_ids"].cuda(), position_ids=batch["position_ids"].cuda()
+            )
+            output_hf = hf_model(input_ids=batch["input_ids"].cuda(), position_ids=batch["position_ids"].cuda())
+
+        # Assertion of the logits
+        # This will always fail! We aren't performing the SAME operations: Nanotron packs the QKV matrices, and the MLP & LayerNorm are different. So we don't have to focus on MATCHING LOGITS BUT GENERATIONS
+        # assert_close(output_hf.logits, output_nanotron.transpose(0, 1), rtol=1e-1, atol=1e-1)
+
+        predicted_tokens = [62, 92, 125, 425, 744, 912, 1298]
+        for predicted_token in predicted_tokens:
+            print(predicted_token)
+            next_tokens_hf = torch.softmax(output_hf.logits[0, predicted_token, :], -1)
+            hf_topk_next_tokens = torch.topk(next_tokens_hf, 10)
+
+            next_tokens_nanotron = torch.softmax(output_nanotron.transpose(0, 1)[0, predicted_token, :], -1)
+            nanotron_topk_next_tokens = torch.topk(next_tokens_nanotron, 10)
+            assert all(
+                hf_topk_next_tokens[1][:TOPK_MATCH] == nanotron_topk_next_tokens[1][:TOPK_MATCH]
+            ), f"HF: {hf_topk_next_tokens[1][:TOPK_MATCH]} \n\n{hf_topk_next_tokens[0][:TOPK_MATCH]}\n\n Nanotron: {nanotron_topk_next_tokens[1][:TOPK_MATCH]}\n\n{nanotron_topk_next_tokens[0][:TOPK_MATCH]}"
+
+        print("All generations match!\nChecking Loss")
+
+        # Loss check
+        nanotron_loss = nanotron_model.loss(
+            sharded_logits=output_nanotron,
+            label_ids=batch["label_ids"].cuda(),
+            label_mask=batch["label_mask"].cuda(),
+        )["loss"]
+
+        # Creating label_ids for HF loss computation
+        hf_labels = hf_build_labels_completions_only(
+            batch["label_ids"].flatten().tolist(), batch["label_mask"].flatten().tolist()
+        )
+        shift_logits = output_hf.logits.contiguous()
+        shift_labels = hf_labels.contiguous()
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, 128256)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to("cuda")
+        hf_loss = loss_fct(shift_logits, shift_labels)
+
+        assert_close(nanotron_loss, hf_loss, atol=1e-2, rtol=1e-2)  # -3 is fine for most cases too
+        print("Loss match!")
+
+    print("\n\n\nBoth generations and losses match!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/todi/Dockerfile b/tools/todi/Dockerfile
new file mode 100644
index 00000000..611ddba0
--- /dev/null
+++ b/tools/todi/Dockerfile
@@ -0,0 +1,15 @@
+FROM nvcr.io/nvidia/pytorch:24.05-py3
+
+# Setup
+RUN apt-get update && apt-get install python3-pip python3-venv -y
+RUN pip install --upgrade pip setuptools==69.5.1
+
+RUN pip install flash-attn==2.5.8 --no-build-isolation
+
+COPY nanotron/ /workspace/nanotron
+WORKDIR /workspace/nanotron
+RUN pip install -e '.[nanosets]'
+
+# Instructions:
+# 1. Build image: podman build -f /users/asolergi/SFT/nanotron/tools/todi/Dockerfile -t nanotron_sft /users/asolergi/SFT/  #### NOTE In /users/asolergi/SFT/ we have nanotron/ (/users/asolergi/SFT/nanotron)
+# 2. Export image: enroot import -o /store/swissai/a06/.sft_toni/nanotron_sft.sqsh podman://localhost/nanotron_sft:latest
diff --git a/tools/todi/nanotron_sft.toml b/tools/todi/nanotron_sft.toml
new file mode 100644
index 00000000..ffa30484
--- /dev/null
+++ b/tools/todi/nanotron_sft.toml
@@ -0,0 +1,15 @@
+image = "/store/swissai/a06/.sft_toni/nanotron_sft.sqsh"
+mounts = [
+"/capstor",
+"/users",
+"/store",
+]
+workdir = "/workspace/nanotron/"
+
+[env]
+FI_CXI_DISABLE_HOST_REGISTER = "1"
+FI_MR_CACHE_MONITOR = "userfaultfd"
+
+[annotations.com.hooks]
+aws_ofi_nccl.enabled = "true"
+aws_ofi_nccl.variant = "cuda12"
diff --git a/tools/todi/submit_nanotron_sft.sh b/tools/todi/submit_nanotron_sft.sh
new file mode 100644
index 00000000..13a6696f
--- /dev/null
+++ b/tools/todi/submit_nanotron_sft.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+
+#SBATCH --job-name nanotron_sft
+#SBATCH --chdir /users/asolergi/SFT/nanotron # TODO Set this path!!!
+#SBATCH --output reports/R-%x.%j.out # Make sure this path exists, otherwise the job will fail silently
+#SBATCH --error reports/R-%x.%j.err # Make sure this path exists, otherwise the job will fail silently
+#SBATCH --nodes 4 # number of Nodes
+#SBATCH --ntasks-per-node 1 # number of MP tasks. IMPORTANT: torchrun represents just 1 Slurm task
+#SBATCH --gres gpu:4 # Number of GPUs
+#SBATCH --cpus-per-task 288 # number of CPUs per task.
+#SBATCH --time 11:59:59 # maximum execution time (DD-HH:MM:SS). Mandatory field in MN5
+#SBATCH --reservation todi
+#SBATCH --environment /store/swissai/a06/.sft_toni/nanotron_sft.toml
+#SBATCH --contiguous
+
+echo "START TIME: $(date)"
+
+# auto-fail on any errors in this script
+set -eo pipefail
+
+# logging script's variables/commands for future debug needs
+set -x
+
+######################
+### Set environment ###
+######################
+GPUS_PER_NODE=4
+echo "NODES: $SLURM_NNODES"
+######################
+
+######################
+#### Set network #####
+######################
+MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+MASTER_PORT=6000
+######################
+
+# note that we don't want to interpolate `\$SLURM_PROCID` till `srun` since otherwise all nodes will get
+# 0 and the launcher will hang
+#
+# same goes for `\$(hostname -s|tr -dc '0-9')` - we want it to interpolate at `srun` time
+LAUNCHER="torchrun \
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $SLURM_NNODES \
+    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+    --rdzv_backend c10d \
+    --max_restarts 0 \
+    --tee 3 \
+    --node_rank ${SLURM_PROCID} \
+    "
+
+PYTHON_FILE=/workspace/nanotron/run_train.py
+NANOTRON_CONFIG=/users/asolergi/SFT/nanotron/examples/config_llama8b_sft.yaml # TODO Set this path!!!
+
+export CMD="CUDA_DEVICE_MAX_CONNECTIONS=1 $LAUNCHER $PYTHON_FILE --config $NANOTRON_CONFIG"
+
+echo $CMD
+
+# srun error handling:
+# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
+SRUN_ARGS=" \
+    --cpus-per-task $SLURM_CPUS_PER_TASK \
+    --jobid $SLURM_JOB_ID \
+    --wait 60 \
+    --unbuffered \
+    "
+
+# bash -c is needed for the delayed interpolation of env vars to work
+srun $SRUN_ARGS bash -c "$CMD"
+
+echo "END TIME: $(date)"
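
For reference, an end-to-end launch on the todi cluster with the files added above might look like the following sketch. It only strings together the commands already given in `tools/todi/Dockerfile`, the check script's docstring, and standard Slurm usage; the paths are the same placeholders used in those files and must be adapted to your environment.

```bash
# Build the container image with nanotron + flash-attn (command from tools/todi/Dockerfile)
podman build -f /users/asolergi/SFT/nanotron/tools/todi/Dockerfile -t nanotron_sft /users/asolergi/SFT/

# Export it to the squashfs image referenced by tools/todi/nanotron_sft.toml
enroot import -o /store/swissai/a06/.sft_toni/nanotron_sft.sqsh podman://localhost/nanotron_sft:latest

# (Optional) sanity-check the SFT model against the HF implementation on a single GPU
torchrun --nproc-per-node 1 tools/check_sft.py

# Submit the multi-node SFT job (remember to set the paths marked with TODO in the script)
sbatch tools/todi/submit_nanotron_sft.sh
```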