Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DPO cleanup #1126

Merged
merged 23 commits into from
Jan 23, 2024
Merged

DPO cleanup #1126

merged 23 commits into from
Jan 23, 2024

Conversation

winglian
Copy link
Collaborator

Description

This PR cleans up some hardcoding, improves the integration with trl's DPOTrainer and adds support for dpo prompt_strategies.

src/axolotl/utils/data.py Outdated Show resolved Hide resolved
Copy link
Contributor

@plaguss plaguss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome PR! I left a comment in case you see fit. Also, maybe it could be tackled in a different PR, but the preprocess command could also be updated to allow checking rl datasets:

+    if parsed_cfg.rl:
+        _ = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

src/axolotl/utils/data.py Outdated Show resolved Hide resolved
@winglian winglian force-pushed the dpo-cleanup branch 2 times, most recently from d5f97c3 to c0a1553 Compare January 23, 2024 02:21

def load(strategy, cfg):
try:
load_fn = strategy.split(".")[-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is most likely not correct. The strategy includes underscores, not ., such as intel_apply_chatml.

Copy link
Contributor

@filippo82 filippo82 Jan 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def load(strategy, cfg):
    try:
        load_fn = strategy.split("_")[-1]
        #strategy = ".".join(strategy.split("_")[:-1])
        LOG.info(load_fn)
        LOG.info(strategy)
        mod = importlib.import_module(f".{load_fn}", "axolotl.prompt_strategies.dpo")
        func = getattr(mod, strategy)
        load_kwargs = {}
        return func(cfg, **load_kwargs)
    except Exception as e:  # pylint: disable=broad-exception-caught
        LOG.warning(e)
        return None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works for me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the intention is the setting is something like

type: chatml.argilla

in which case it will load the argilla function from the axolotl.prompt_strategies.dpo.chatml module.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @winglian 👋🏻 thanks. That makes sense. I will test it later today 👍🏻

@winglian winglian merged commit 7523d1f into main Jan 23, 2024
1 of 6 checks passed
@winglian winglian deleted the dpo-cleanup branch January 23, 2024 05:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants