
Adding support for training chat models #187

Closed
wants to merge 3 commits

Conversation

TJ-Solergibert
Contributor

Caution

🚨 This is a draft, still in development, and further testing needs to be done. Feel free to leave any comments!

This PR includes everything necessary to train chat models with:

  1. Sample packing
  2. No cross-attention contamination between packed samples
  3. Training on completions only (the answers generated by the model)

This image from @sz128 is very useful for understanding 1. & 2.:
[image: sample packing illustration]

I am developing this feature with axolotl's implementation as a reference. The current status is as follows:

Dataset

IterableDatasets

This time, I have opted for an IterableDataset instead of a map-style one. The obvious benefits are that we tokenize on the fly, which makes it easy to experiment with different models/tokenizers/chat templates and saves disk space by not storing the tokens. However, the drawbacks are:

  • The split between the different DP groups is done with the split_dataset_by_node function, which divides the dataset's n_shards among the DP groups (a usage sketch follows after this list). If n_shards is not evenly divisible by the number of DP groups, each group instead keeps 1 example out of every world_size examples, skipping the others, which is clearly suboptimal. Also, remember that since we tokenize the data on the fly, even when splitting by n_shards, each DP group will produce a different number of tokens. Thus, after X steps, one DP group may have consumed its share of the dataset 1.5 times, while another with longer samples has only consumed it 0.8 times.
  • Since each sample of the dataset produces a different number of tokens, it is impossible to predict how many dataset samples are needed to build each batch of the dataloader. This, together with the fact that batches are built from a buffer that keeps consuming samples until it has enough tokens, complicates recovering the dataloader state when resuming training. For now, there is NO way to recover the state, but it could be achieved with solutions like StatefulDataLoader, which checkpoint the state of the DataLoader and Dataset. In our case, we would only need to checkpoint, for each DP group, the number of samples extracted from the Dataset plus the token buffers.
  • The most concerning issue today is that recovering the state requires the .skip() method. This method is far from optimal, as it still iterates over all the skipped samples, although it seems they are working on a better solution. This is a problem if you have to skip many samples from an XXXL dataset.
  • It is very complicated (though not impossible) to work with num_workers > 1, as the dataset would need to be split at the worker level, and recovering the state would no longer be trivial if this value changes.

Of all these inconveniences, the one that worries me most is the third one, but I trust that an optimal solution will be developed soon. We can easily develop solutions for the first and second issues, and the fourth does not seem too problematic, although we could also address it.
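For reference, a minimal sketch of how the streaming dataset is sharded across DP groups with the Hugging Face datasets API (the dataset name and the dp_rank / dp_world_size values are placeholders for this sketch):

```python
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

dp_rank, dp_world_size = 0, 4  # placeholder values for this sketch

# Stream the dataset so tokens are never materialized on disk.
dataset = load_dataset("Open-Orca/SlimOrca", split="train", streaming=True)

# If dataset.n_shards is divisible by dp_world_size, each DP group gets
# n_shards / dp_world_size shards; otherwise each group keeps 1 example
# out of every dp_world_size, skipping the rest.
dataset = split_dataset_by_node(dataset, rank=dp_rank, world_size=dp_world_size)

for sample in dataset:
    ...  # tokenize on the fly and feed the packing buffer
```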

How Samples Are Produced

In short, we extract samples from the dataset and apply the chat template until we can no longer fit a full Question-Answer pair into the sequence length of the sample we are constructing. We save this last Question-Answer pair for the next sample and pad the sample we are constructing (in the case of the Llama3 tokenizer, since there is no pad token, we use the <|eot_id|> token). We do this so that each sample contains only complete Question-Answer pairs, as sketched below. This packing is greedy, although there are more complex strategies to minimize the number of pad tokens.
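A minimal sketch of that greedy packing loop (helper names such as `tokenize_qa_pair` are hypothetical; the PR's implementation is the reference):

```python
def pack_samples(qa_pairs, tokenize_qa_pair, sequence_length, pad_token_id):
    """Greedily pack complete Question-Answer pairs into fixed-length samples.
    Assumes a single tokenized pair never exceeds sequence_length."""
    buffer = []
    for qa in qa_pairs:
        tokens = tokenize_qa_pair(qa)  # apply the chat template and tokenize
        if len(buffer) + len(tokens) > sequence_length:
            # The pair does not fit: pad the current sample and start a new one
            # with this pair, so every sample only contains complete pairs.
            yield buffer + [pad_token_id] * (sequence_length - len(buffer))
            buffer = []
        buffer.extend(tokens)
    if buffer:
        yield buffer + [pad_token_id] * (sequence_length - len(buffer))
```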

The important thing here is that we have developed the ChatTokenizer class to apply the chat template manually instead of relying on solutions like the tokenizers' apply_chat_template method. We do this to know, at the token level, whether each token belongs to the assistant's responses, which is required for training only on the assistant's tokens. I have added an assert to verify that the result of applying the chat template manually is exactly the same as that of the tokenizers' apply_chat_template method.
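To illustrate the idea (this is not the PR's ChatTokenizer; the template strings below are my reading of the Llama3 chat format, and tokenizing the pieces separately may not match apply_chat_template for every input):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

def encode_message(role: str, content: str):
    """Apply the chat template to one message by hand, returning the token ids
    together with a per-token flag marking the assistant's tokens."""
    header_ids = tokenizer.encode(f"<|start_header_id|>{role}<|end_header_id|>\n\n",
                                  add_special_tokens=False)
    body_ids = tokenizer.encode(f"{content.strip()}<|eot_id|>", add_special_tokens=False)
    # Only the assistant's own text (and its <|eot_id|>) is trained on.
    return header_ids + body_ids, [False] * len(header_ids) + [role == "assistant"] * len(body_ids)

def encode_conversation(messages):
    ids = tokenizer.encode("<|begin_of_text|>", add_special_tokens=False)
    is_assistant = [False] * len(ids)
    for message in messages:
        message_ids, message_mask = encode_message(message["role"], message["content"])
        ids += message_ids
        is_assistant += message_mask
    # Mirrors the assert described above: the manual template must match
    # the tokenizer's own apply_chat_template output.
    assert ids == tokenizer.apply_chat_template(messages)
    return ids, is_assistant
```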

Dataset Samples

I have developed this notebook so you can check the batches produced by the DataLoader. In summary, the most relevant features are:

Note

The label id token '-' is actually -100. We display it this way because tokenizer.convert_ids_to_tokens can't convert the -100 token.
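A tiny helper like the following (hypothetical, not necessarily the notebook's code) produces that kind of display:

```python
def labels_to_tokens(labels, tokenizer, ignore_index=-100):
    """Render label ids for inspection: positions ignored by the loss become '-'."""
    return ["-" if label == ignore_index else tokenizer.convert_ids_to_tokens(label)
            for label in labels]
```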

  • When training only on the assistant's answers, the first token we predict is ĊĊ, which corresponds to "\n\n" from the Llama3 chat template. When interacting with a model, we prompt a question, apply the chat template, and the model starts generating from this "\n\n" token.
[screenshot: labels when training on completions only]
  • Every Question-Answer pair has its own position_ids. This will be relevant later for telling FA2 not to attend to other samples (see the sketch after this list).
[screenshot: per-pair position_ids]
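A sketch of how such position_ids can be built from the packed segment lengths (a hypothetical helper, not the PR's exact code; the padding tail is treated here as one extra segment):

```python
import torch

def build_position_ids(segment_lengths, sequence_length):
    """Restart positions at 0 for every packed Question-Answer pair; the padding
    tail forms one extra segment of its own."""
    segments = [torch.arange(length) for length in segment_lengths]
    segments.append(torch.arange(sequence_length - sum(segment_lengths)))
    return torch.cat(segments)

# Two packed pairs of 5 and 3 tokens in a sequence of 10:
# tensor([0, 1, 2, 3, 4, 0, 1, 2, 0, 1])
print(build_position_ids([5, 3], 10))
```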

Other Considerations

  • We flatten the micro batch (MBS) to reduce the number of padding tokens.
  • For now, we only support the Open-Orca/SlimOrca dataset format. Check the keys & values of the dicts.
  • Open-Orca/SlimOrca is a single-turn conversation dataset, so we append the <|end_of_text|> token after each assistant message.

Collator

The collator is very similar to DataCollatorForCLM, except that we now also add the position_ids. I have also removed several assertions.
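Very roughly, the collator produces something along these lines (a simplified sketch; the key names and shapes here are assumptions, not taken from nanotron's DataCollatorForCLM or from this PR):

```python
import numpy as np

def collate(examples):
    """Stack pre-packed samples and shift them into inputs/labels, carrying the
    per-sample position_ids through to the model."""
    input_ids = np.stack([ex["input_ids"] for ex in examples])        # (MBS, seq_len + 1)
    position_ids = np.stack([ex["position_ids"] for ex in examples])  # (MBS, seq_len)
    is_completion = np.stack([ex["is_completion"] for ex in examples])
    return {
        "input_ids": input_ids[:, :-1],
        "position_ids": position_ids,
        "label_ids": input_ids[:, 1:],
        # When training on completions only, mask out every non-assistant label.
        "label_mask": is_completion[:, 1:],
    }
```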

DataLoader

The DataLoader is pretty simple. As mentioned earlier, it is not trivial to make it work with num_workers > 1.

Config

I have added a new configuration called ChatDatasetsArgs which includes:

  • hf_dataset: The name or path of the dataset we will stream.
  • hf_dataset_split ("train"): The split of the dataset.
  • conversation_column_name: The name of the column from which to extract the conversations.

And, for debugging purposes (to be deleted in the final release):

  • train_on_completions_only: Whether to train only on completions or not.
  • remove_cross_attention: Whether to attend only to tokens from the same packed sample or to all tokens (vanilla mechanism).

As already mentioned, the final two configurations are there to evaluate the effect of these two functionalities. I would remove them for the final release, since I do not see the benefit of not activating them.
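For reference, the new config roughly corresponds to a dataclass like the following (a sketch of the fields listed above, not necessarily the exact definition in the PR):

```python
from dataclasses import dataclass

@dataclass
class ChatDatasetsArgs:
    hf_dataset: str                    # name or path of the dataset to stream
    conversation_column_name: str      # column holding the conversations
    hf_dataset_split: str = "train"    # dataset split
    # Debugging flags, to be removed in the final release:
    train_on_completions_only: bool = True
    remove_cross_attention: bool = True
```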

What Is Still Missing:

  • I still need to implement everything related to FA2 and the mechanism to eliminate cross-attention contamination. The most relevant part, which belongs to point 2 (no cross-attention contamination between packed samples), can be found here. I hope to finish it this week and present results with comparisons starting next week.
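For context, removing cross-attention contamination with FA2 is typically done by feeding cumulative sequence lengths to its variable-length kernel; here is a sketch of deriving them from the packed position_ids (my assumption about the approach, not the code in this PR):

```python
import torch
# from flash_attn import flash_attn_varlen_func  # FA2's variable-length attention kernel

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Every position equal to 0 starts a new packed segment; FA2's varlen kernel
    takes the cumulative segment lengths and will not attend across segments."""
    flat = position_ids.view(-1)
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
    total = torch.tensor([flat.numel()], device=flat.device)
    return torch.cat([starts, total]).to(torch.int32)

# Usage (conceptually): flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
#                                              max_seqlen, max_seqlen, causal=True)
print(cu_seqlens_from_position_ids(torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 0, 1])))
# tensor([ 0,  5,  8, 10], dtype=torch.int32)
```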

TODOs:

  • Checkpoint DataLoader states to resume training (StatefulDataLoader).
  • Support interleaving datasets.
  • Improve helper functions.
  • Add tests.
  • Clarify the chat dataset format (dictionary keys).

@xrsrke
Member

xrsrke commented Aug 2, 2024

Closing in favor of swiss-ai#14.

@xrsrke closed this Aug 2, 2024