Deleting cross attention between documents during pretraining #16

Open · wants to merge 23 commits into main
Conversation

@TJ-Solergibert (Collaborator) commented Aug 22, 2024

In this PR I include the mechanism to delete cross attention between different documents during pretraining. I'm developing this PR on top of #14, as it reuses most of the code for the Llama model.

To use this feature you will need to tokenise the data with the updated tool (it just contains a small patch, soon to be merged into main, to NOT shuffle tokens with datatrove), and it is necessary to add the eos_token:
```bash
python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --eos-token "<|end_of_text|>" --output-folder datasets/SlimPajama-6B --n-tasks 16 hf --dataset DKYoon/SlimPajama-6B
```

Then, in the config file, you need to set `remove_document_xattention` to True:

```yaml
- data:
    dataset:
      dataset_folder: /mloscratch/homes/solergib/SFT/nanotron/datasets/SlimPajama-6B
      remove_document_xattention: True
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Single dataset)
  start_training_step: 1
```

This will build the LlamaForSFT model and configure the collator to produce the correct `position_ids` & `label_mask`.

In the collator we create the correct position ids for each document; in short, we reset the position ids after every eos_token. For the label_mask, we mask the loss for the token preceding the eos_token, so the model doesn't learn to predict the eos_token, and for the eos_token itself, as it doesn't make sense to compute the loss for the prediction made from the eos_token. In this image you can see what we are feeding into the model, with the correct values for the position_ids and label_mask.

[image: packed sample showing the tokens fed to the model with their position_ids and label_mask]

And another example for a sample with more than 2 documents (this is the boundary between the 3rd & 4th documents):

[image: boundary between the 3rd & 4th documents, showing the position_ids resetting and the masked label_mask positions around the eos_token]
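
To make the collator behaviour concrete, here is a minimal, illustrative sketch of how the position_ids and label_mask can be derived from a single packed sequence. This is not the actual nanotron collator; the function name and the 1D layout are assumptions:

```python
import torch

def build_position_ids_and_label_mask(input_ids: torch.Tensor, eos_token_id: int):
    """Illustrative only: reset position ids after every eos_token and mask the loss
    for the token preceding each eos_token and for the eos_token itself."""
    seq_len = input_ids.shape[0]
    position_ids = torch.empty(seq_len, dtype=torch.long)
    label_mask = torch.ones(seq_len, dtype=torch.bool)

    pos = 0
    for i, token in enumerate(input_ids.tolist()):
        position_ids[i] = pos
        pos += 1
        if token == eos_token_id:
            label_mask[i] = False          # no loss on the prediction made from the eos_token
            if i > 0:
                label_mask[i - 1] = False  # no loss on predicting the eos_token itself
            pos = 0                        # next document starts again at position 0

    return position_ids, label_mask

# Two documents packed into one sample (eos_token_id = 2):
# input_ids    -> [5, 6, 7, 2, 8, 9, 2]
# position_ids -> [0, 1, 2, 3, 0, 1, 2]
# label_mask   -> [1, 1, 0, 0, 1, 0, 0]
```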

The main difference between LlamaForSFT & LlamaForTraining is that LlamaForTraining leverages the FlashAttention (FA) RoPE embeddings with Triton for better performance but doesn't support position ids. If we manage to develop a custom Triton kernel for RoPE with position ids, we could keep just LlamaForSFT.
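
For reference, applying RoPE with explicit position ids in plain PyTorch roughly amounts to gathering the cos/sin tables with the per-document positions instead of assuming positions 0..seq_len-1. This is only a sketch with the usual rotate-half convention and HF-style cos/sin tables, not the kernel we would actually need:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_with_position_ids(q, k, cos_table, sin_table, position_ids):
    """q, k: [batch, seq_len, n_heads, head_dim]
    cos_table, sin_table: [max_seq_len, head_dim] (halves duplicated, HF-style)
    position_ids: [batch, seq_len], reset to 0 at every document boundary."""
    cos = cos_table[position_ids].unsqueeze(2)  # [batch, seq_len, 1, head_dim]
    sin = sin_table[position_ids].unsqueeze(2)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```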

@TJ-Solergibert (Collaborator, Author) commented:

Hi! In the last commit I add 2 Triton kernels from LinkedIn/Liger-Kernel to LlamaForSFT: more precisely, a kernel to apply RoPE embeddings and one for the SiLU activation function in the MLP block.

I updated the `tools/check_remove_xattention.py` script and we are matching the HF generations. But there are 2 things to take into account:

1. For the SiLU kernel we are splitting the MLP gate projection & up projection matrices again. In nanotron they are fused to squeeze out a bit more performance, but the kernel expects 2 different matrices. Performance-wise this is not much of a problem, as these 2 GEMMs have a very high arithmetic intensity, BUT it will require taking a look at how we want to train our models and store checkpoints, as we might either IMPOSE 1 setting or craft some conversion logic between the FUSED / NOT FUSED MLP gate/up projections (a sketch of such a conversion is included after this list).
2. The kernels are written in pure Triton, BUT we can't pip install liger-kernel because triton doesn't exist for ARM (tödi). Instead, we have pytorch-triton 3.0.0+989adb9a2 inside the NGC containers, which is essentially triton, but the install manager is not able to resolve this conflict. That's why I included the code of the project with its license. I don't know what the best solution would be: either fork the original project and delete the dependency, keep the code in this repo, or, my least favourite option, ask the user to install it manually.
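
As an illustration of the conversion option in point 1, splitting a fused gate/up projection into the two matrices the Liger SiLU kernel expects could look like the sketch below. This assumes the fused weight is a plain concatenation along the output dimension, which may not match nanotron's actual layout:

```python
import torch.nn as nn

def split_fused_gate_up(fused: nn.Linear, intermediate_size: int):
    """Hypothetical conversion: a fused [2 * intermediate_size, hidden_size] weight
    -> separate gate_proj and up_proj Linear layers."""
    hidden_size = fused.in_features
    # Assumes rows 0..intermediate_size-1 hold the gate projection and the rest the up projection.
    gate_w, up_w = fused.weight.data.split(intermediate_size, dim=0)

    gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
    up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
    gate_proj.weight.data.copy_(gate_w)
    up_proj.weight.data.copy_(up_w)
    return gate_proj, up_proj
```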

Throughout this week I'll run some benchmarks and resolve the 2 issues mentioned above.
