
Adding SFT training #14

Open · wants to merge 10 commits into main

Conversation
Conversation

@TJ-Solergibert (Collaborator) commented on Jul 30, 2024

Hello,

In this PR, I include everything necessary to perform SFT with the llama model. The main features included are:

  1. Sample packing of multiple conversations (questions & answers).
  2. Training on completions only: we compute the loss only on the tokens that belong to the answers generated by the model, not on the questions asked by the human or on the chat template.
  3. No cross-attention between samples, so each token only attends to the tokens of its own conversation. A minimal sketch of the masking behind features 2 & 3 follows this list.
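
To make features 2 & 3 concrete, here is a minimal, self-contained sketch (not the PR's actual code; the token values, role spans and lengths below are made up) of the label masking and the per-sample position ids / cu_seqlens that drive them:

    # Minimal sketch of features 2 & 3 on a packed sequence (illustrative values only).
    import torch

    IGNORE_INDEX = -100  # ignored by torch.nn.functional.cross_entropy

    # A packed sequence holding two conversations of lengths 5 and 3.
    input_ids = torch.tensor([[101, 7, 8, 9, 102, 101, 4, 102]])
    sample_lengths = [5, 3]

    # Feature 2: train on completions only. Assume the assistant tokens sit at
    # positions 2-3 and 6; every other token (human turns, chat template) gets
    # IGNORE_INDEX and therefore contributes no loss.
    is_completion = torch.tensor([[0, 0, 1, 1, 0, 0, 1, 0]], dtype=torch.bool)
    labels = input_ids.clone()
    labels[~is_completion] = IGNORE_INDEX

    # Feature 3: no cross-attention between samples. position_ids restart at 0 for
    # every sample and cu_seqlens marks the sample boundaries, so the attention
    # kernel only looks inside each conversation.
    position_ids = torch.cat([torch.arange(l) for l in sample_lengths]).unsqueeze(0)
    cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32)

    print(labels)        # tensor([[-100, -100,  8,  9, -100, -100,  4, -100]])
    print(position_ids)  # tensor([[0, 1, 2, 3, 4, 0, 1, 2]])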

As detailed below, this first beta allows toggling features 2 & 3. I designed it this way to measure the effect of these parameters, although I propose removing the toggles in the final version.

In my first PR in the nanotron repo (huggingface#187), I used the axolotl implementation as a reference. The problem was that it added padding tokens to fill the sequence length. I finally opted for a padding-free implementation, using the new implementation from HuggingFace Transformers as a reference [1], [2]. I included the script tools/check_sft.py to compare the generations of both models (HF & nanotron) and ensure they are the same. I emphasize that it is the generations that match, not the logits: although both implementations have the same parameters, they do not perform exactly the same operations. In nanotron we have 1. a fused QKV matrix, 2. a fused MLP matrix, and 3. FA LayerNorm, which produce slightly different logits (with torch.testing.assert_close, 99% of the logits are equal at atol=rtol=1e-2), but the important thing is that the generations are the same, especially the most probable first token.
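
For reference, the comparison can be sketched roughly as follows (this is not tools/check_sft.py itself; hf_logits and nanotron_logits are assumed to be the [batch, seq_len, vocab] outputs of the two models on identical inputs):

    # Rough sketch of the logits/generation check; the real script is tools/check_sft.py.
    import torch

    def compare_outputs(hf_logits: torch.Tensor, nanotron_logits: torch.Tensor) -> None:
        # The logits only match loosely because of the fused QKV/MLP and FA LayerNorm.
        close = torch.isclose(hf_logits, nanotron_logits, atol=1e-2, rtol=1e-2)
        print(f"{close.float().mean().item():.2%} of logits within atol=rtol=1e-2")

        # What matters for the sanity check: the greedy next token (the most probable
        # first generated token) must be identical between both implementations.
        assert torch.equal(hf_logits[:, -1].argmax(dim=-1), nanotron_logits[:, -1].argmax(dim=-1))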

Here & here you can see the wandb runs for the 4 configs obtained by toggling features 2 & 3. As can be seen, enabling feature 3 increases the TFLOPs, since flash_attn_varlen_func achieves better performance when attending to shorter sequences.

Details & Functionality

In this first "Beta", I introduce 1. a new Dataset, ChatTokenizer & Collator and 2. a new Llama model for SFT (LlamaForSFT).

  1. We only need to specify a QA dataset from the HuggingFace Hub in the config file. Unlike Nanosets, no preprocessing step is required: an IterableDataset handles tokenization + sample packing on the fly. The obvious benefit is that we don't need to tokenize the data beforehand, but it has a major drawback: it is not trivial to recover the state of the DataLoader to resume training after an interruption. The only solution I know of is torchdata's StatefulDataLoader, which I am already working on for the final version. We can also activate and deactivate features 2 and 3 via the train_on_completions_only and remove_cross_attention settings. Finally, remember that we only support the conversation format of Open-Orca/SlimOrca & Magpie-Align/Magpie-Pro-300K-Filtered, so if you want to use other QA datasets (like this dataset with "content" and "role" keys), you will need to change the dictionary keys.

    - data:
        dataset:
          hf_dataset: Magpie-Align/Magpie-Pro-300K-Filtered
          hf_dataset_split: train
          conversation_column_name: conversations
          train_on_completions_only: true
          remove_cross_attention: true
        num_loading_workers: 1
        seed: 42
      name: General purpose training (Single dataset)
      start_training_step: 1

    Finally, to apply the chat template and tokenize the data, I included ChatTokenizer, very similar to the one in meta-llama/llama3, with the difference that we also register the role of each token, which is necessary for feature 2 (a rough illustration follows below).
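
    As a rough illustration of the idea (hypothetical code, not the actual ChatTokenizer; the header format and the "from"/"value"/"gpt" keys follow the SlimOrca-style conversations mentioned above):

      # Hypothetical sketch: tokenize turn by turn and record which tokens belong to the
      # assistant, so the collator can mask everything else for feature 2.
      def tokenize_conversation(conversation, tokenizer):
          input_ids, is_completion = [], []
          for turn in conversation:  # e.g. {"from": "human" or "gpt", "value": "..."}
              header = tokenizer.encode(f"<|{turn['from']}|>\n", add_special_tokens=False)
              body = tokenizer.encode(turn["value"], add_special_tokens=False)
              input_ids += header + body
              # Only assistant ("gpt") tokens keep their labels; header/template tokens never do.
              is_completion += [False] * len(header) + [turn["from"] == "gpt"] * len(body)
          return input_ids, is_completion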

  2. LlamaForSFT only supports SFT training. I have removed everything related to inference from nanotron checkpoints with the run_generate.py script, since we have never tested it and do not intend to. I included the RoPE embeddings from HF Transformers which, although their performance is poor compared to FlashAttention's Triton RoPE kernels, are the only ones I have found that support position ids (necessary for feature 3). In the future, we could try to write a kernel for this. Also, feature 3 requires using flash_attn_varlen_func instead of flash_attn_func.

    Keep in mind that since we are already packing multiple samples, tokens.micro_batch_size will always be 1, so the maximum number of tokens per micro-batch is tokens.micro_batch_size * tokens.sequence_length.
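
    For reference, a minimal sketch of the variable-length attention call (the shapes and the cu_seqlens derivation are illustrative, not the exact LlamaForSFT code):

      # Sketch: with micro_batch_size = 1, q/k/v are flattened to (total_tokens, n_heads, head_dim)
      # and cu_seqlens marks where each packed conversation starts and ends.
      import torch
      from flash_attn import flash_attn_varlen_func

      def packed_attention(q, k, v, position_ids):
          # position_ids restart at 0 for every packed sample, so sample starts are where position_id == 0.
          flat = position_ids.view(-1)
          starts = torch.nonzero(flat == 0).view(-1)
          cu_seqlens = torch.cat([starts, torch.tensor([flat.numel()], device=flat.device)]).to(torch.int32)
          max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
          return flash_attn_varlen_func(
              q, k, v,
              cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
              max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
              causal=True,  # each token attends only to earlier tokens of its own conversation
          )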

TODOs

  • Write DOCS
  • Efficient RoPE embeddings
  • Ability to recover DataLoader states from an interruption
  • Delete the options to experiment with features 2 & 3
  • Delete tools/check_sft.py
  • Delete tools/todi
  • Delete convert_hf_nanotron.ipynb
