Add a chat_template prompt strategy for DPO #1725

Open · wants to merge 4 commits into base: main

Conversation

fozziethebeat (Contributor)

Description

Replicates the chat_template support from SFT datasets, but for DPO training. Users can now specify a dataset with a list of conversation messages, along with chosen and rejected columns that each hold a single conversation message. Further, all field names can be customized.
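For example, a row in such a dataset might look like the following (a hypothetical sample; the field names shown are assumed defaults, and each is remappable):

```python
# Hypothetical example row; "messages", "chosen", "rejected", "role", and
# "content" are assumed default field names, all configurable.
sample = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    # chosen/rejected each hold a single conversation message
    "chosen": {"role": "assistant", "content": "The capital of France is Paris."},
    "rejected": {"role": "assistant", "content": "Paris is in Italy."},
}
```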

Motivation and Context

This change provides a more configurable set of datasets for DPO training.
Fixes #1708

How has this been tested?

  • Unit test added for the new strategy
  • Manual preprocessing run over a sample dataset
  • Full training completed on a real dataset

Types of changes

  • Code changes to prompt strategies
  • Unit tests

Social Handles (Optional)

@fozziethebeat

This mimics the SFT chat_template strategy so that users can (a rough sketch follows this list):
* Specify the messages field
* Specify the per-message role and content fields
* Specify the chosen and rejected fields
* Let the tokenizer construct the raw prompt
* Ensure the chosen and rejected fields don't have any prefix tokens
@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
     """Helper function to process and color tokens."""
     colored_tokens = [
         color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
-        for token in tokenizer.encode(tokens)
+        for token in tokenizer.encode(tokens, add_special_tokens=False)
fozziethebeat (Contributor, Author):

Note: I added this because, by default, this step was including the BOS token every time. Since the chat template already includes it, it seemed reasonable not to add it a second time.
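To illustrate (a minimal sketch, assuming a Llama-3-style template that already bakes BOS into the rendered prompt string):

```python
# Sketch of the duplicated-BOS issue; any chat model whose template
# emits BOS in the rendered string would show the same behavior.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hi"}], tokenize=False
)

# Default encoding prepends BOS again, duplicating the one in the template.
print(tokenizer.encode(prompt)[:2])
# Skipping special tokens keeps only the BOS already in the prompt string.
print(tokenizer.encode(prompt, add_special_tokens=False)[:1])
```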

return tokenizer


class TestAssistantChatTemplateLlama3:
Collaborator:

Suggested change:
- class TestAssistantChatTemplateLlama3:
+ class TestAssistantDPOChatTemplateLlama3:

fozziethebeat (Contributor, Author):

Done

winglian (Collaborator) commented on Jul 5, 2024:

@fozziethebeat but for DPO training, since trl handles the tokenization, do we need this piece?

fozziethebeat (Contributor, Author) replied:

> @fozziethebeat but for DPO training, since trl handles the tokenization, do we need this piece?

Was this in reference to the change in the debugging output? If so, it's not required, but I think anyone manually inspecting tokenization output (like I did) would be very surprised to see the BOS token duplicated in numerous scenarios. So it's more to give confidence that we constructed the strings correctly.

fozziethebeat (Contributor, Author):

Any other changes to add before updating the branch and approving for merging?

Successfully merging this pull request may close: Add a chat_template strategy for DPO datasets (#1708)