
DPO Prompt Strategies only support single-turn and will fail silently on multi-turn datasets #1645

bjoernpl opened this issue May 21, 2024 · 1 comment
bjoernpl commented May 21, 2024

Please check that this issue hasn't been reported before.

  • I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

I would expect the DPO prompt strategies to support multi-turn conversations for datasets such as https://huggingface.co/datasets/argilla/distilabel-capybara-dpo-7k-binarized. If multi-turn data is not supported, the strategy should at least warn the user instead of failing silently; see the sketch below.
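
For instance, a guard along these lines could surface the problem (a hypothetical check, not existing axolotl code; `sample` follows the chosen/rejected list-of-turns layout shown below):

import logging

LOG = logging.getLogger(__name__)

def warn_if_multiturn(sample):
    # Each user/assistant exchange contributes two entries, so anything
    # beyond two turns would be silently dropped by a single-turn strategy.
    if len(sample["chosen"]) > 2:
        LOG.warning(
            "argilla_chat is single-turn only; %d extra turns will be dropped",
            len(sample["chosen"]) - 2,
        )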

Current behaviour

Currently the llama3 argilla_chat prompt strategy explicitly consumes only a single-turn conversation, and because it indexes turns [0] and [1] directly, it does not error out or fail when the conversation is longer; the extra turns are silently dropped:

def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations
    """

    def transform_fn(sample):
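        # Note: the hard-coded indices [0] and [1] below mean only the first
        # user/assistant exchange is used; later turns are silently ignored.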
        sample[
            "prompt"
        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
        sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
        return sample

    return transform_fn

Steps to reproduce

from datasets import load_dataset
from axolotl.prompt_strategies.dpo.llama3 import argilla_chat

ds = load_dataset("argilla/distilabel-capybara-dpo-7k-binarized", split="train")
transform_fn = argilla_chat(None)
print(transform_fn(ds[0]))

Observe that chosen and rejected are now equal: only the first exchange of the longer conversation is used, and its assistant turn belongs to the shared prefix, so the preference pair carries no training signal.
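
Appending a quick check to the snippet above makes the silent failure explicit (assuming, as in capybara, that a multi-turn sample shares every turn except the final assistant reply):

out = transform_fn(ds[0])
# The first assistant turn sits in the shared conversation prefix, so the
# transformed preference pair is degenerate:
assert out["chosen"] == out["rejected"]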

Config yaml

No response

Possible solution

A possible solution would be to handle multi-turn conversations explicitly:

def argilla_chat_multiturn(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations
    """

    def transform_fn(sample):
        # Emit the optional system turn first and remember where the
        # user/assistant pairs begin.
        if sample["chosen"][0]["role"] == "system" and sample["chosen"][0]["content"]:
            sample["prompt"] = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|>"
            )
            first_turn = 1
        else:
            sample["prompt"] = ""
            first_turn = 0
        # Render every completed user/assistant pair except the final exchange.
        for turn_index in range(first_turn, len(sample["chosen"]) - 2, 2):
            user = sample["chosen"][turn_index]["content"]
            assistant = sample["chosen"][turn_index + 1]["content"]
            sample["prompt"] += (
                f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>"
            )
        # The final user turn closes the prompt; the last assistant turns of
        # chosen and rejected become the completions to compare.
        sample["prompt"] += (
            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][-2]['content']}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        sample["chosen"] = f"{sample['chosen'][-1]['content']}<|eot_id|>"
        sample["rejected"] = f"{sample['rejected'][-1]['content']}<|eot_id|>"
        return sample

    return transform_fn
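
Applied to the repro above, the multi-turn variant yields a non-degenerate pair (a quick sanity check against the same `ds`):

transform_fn = argilla_chat_multiturn(None)
out = transform_fn(ds[0])
# The prompt now carries the full shared prefix, and the completions are the
# two distinct final assistant replies:
assert out["chosen"] != out["rejected"]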

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10

axolotl branch-commit

main/22ae21a

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.
bjoernpl added the bug label May 21, 2024
winglian (Collaborator) commented

I'm wondering if we can use chat_templates to handle this in a more sustainable way.
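
For example, one shape that could take (a sketch only; it assumes the llama3-instruct tokenizer, whose chat template and eos_token match the <|eot_id|> format above, and `dpo_chat_template_transform` is a hypothetical name):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

def dpo_chat_template_transform(sample):
    # Render everything except the final assistant reply through the
    # tokenizer's chat template, so no llama3 tokens are hard-coded here.
    shared_prefix = sample["chosen"][:-1]
    sample["prompt"] = tokenizer.apply_chat_template(
        shared_prefix, tokenize=False, add_generation_prompt=True
    )
    # eos_token is the turn terminator (<|eot_id|> for llama3-instruct).
    sample["chosen"] = sample["chosen"][-1]["content"] + tokenizer.eos_token
    sample["rejected"] = sample["rejected"][-1]["content"] + tokenizer.eos_token
    return sample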
