DPO cleanup #1126

Merged: 23 commits, Jan 23, 2024
Commits (23)
6a23720
cleanup dpo to be a little more extensible, add zephyr/nectar strategy
winglian Jan 12, 2024
2f7242d
fix eos slash
winglian Jan 13, 2024
74cbe2a
support for eval split
winglian Jan 13, 2024
1f28327
fix kwargs
winglian Jan 13, 2024
cb2f774
handle empty evals
winglian Jan 13, 2024
a91e0cb
don't load peft model for dpo
winglian Jan 13, 2024
bcae290
ensure dpo training args gets bf16 for peft if applicable
winglian Jan 13, 2024
a24c756
fix duplicate kwargs for bf16
winglian Jan 13, 2024
d5e12dd
make sure to respect the configured lr scheduler
winglian Jan 15, 2024
9a746fa
support trainer callback to push config to wandb
winglian Jan 16, 2024
60f566c
set dataloader preload args
winglian Jan 16, 2024
c41391d
ensure that we are loading the lora when merging
winglian Jan 17, 2024
1cfd179
Update src/axolotl/utils/data.py
winglian Jan 18, 2024
02fc8f9
support local datasets for dpo
winglian Jan 23, 2024
7141fd1
chore: lint
winglian Jan 23, 2024
44a6f2d
dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names
winglian Jan 23, 2024
7f3b7ce
add split to dpo tests
winglian Jan 23, 2024
52a227d
fix rebase/merging error
winglian Jan 23, 2024
064b20e
handle edge case w logging
winglian Jan 23, 2024
c49315f
use accelerator for dpo datasets so it doesn't break the logger
winglian Jan 23, 2024
72fb877
missing args
winglian Jan 23, 2024
29663d8
validate checkpoint is an adapter for now
winglian Jan 23, 2024
76b5c2d
log warning when dataset strategy is not loadable
winglian Jan 23, 2024
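
These commits surface as config-level options read by the new builder code. As a hedged illustration only, here is a DPO cfg fragment exercising them, expressed as the DictDefault the builder receives; the field names come from the cfg reads in the diffs below, while the values and dataset paths are hypothetical:

# Illustrative cfg fragment; not an official schema.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "rl": "dpo",  # or "ipo" / "kto_pair"
        "datasets": [
            # "chatml.argilla" loads argilla() from axolotl.prompt_strategies.dpo.chatml
            {"path": "org/some-dpo-pairs", "split": "train", "type": "chatml.argilla"},
        ],
        "test_datasets": [
            {"path": "org/some-dpo-pairs", "split": "test", "type": "chatml.argilla"},
        ],
        "dpo_beta": 0.1,
        "lr_scheduler": "cosine",
        "dataloader_pin_memory": True,
        "dataloader_num_workers": 2,
        "dataloader_prefetch_factor": 4,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
    }
)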
src/axolotl/cli/__init__.py (2 additions, 77 deletions)
@@ -17,7 +17,6 @@
 # add src to the pythonpath so we don't need to pip install this
 from accelerate.commands.config import config_args
 from art import text2art
-from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -30,7 +29,7 @@
     normalize_config,
     validate_config,
 )
-from axolotl.utils.data import prepare_dataset
+from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
@@ -343,81 +342,7 @@ def load_rl_datasets(
     cfg: DictDefault,
     cli_args: TrainerCliArgs,  # pylint: disable=unused-argument
 ) -> TrainDatasetMeta:
-    train_datasets: List[Any] = []
-    for i, ds_cfg in enumerate(cfg.datasets):
-        train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
-    # eval_dataset = load_dataset(
-    #    cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
-    # )
-    eval_dataset = None
-
-    def argilla_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
-        return sample
-
-    def intel_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
-        return sample
-
-    def apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
-        return sample
-
-    def ultra_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
-        return sample
-
-    for i, data_set in enumerate(train_datasets):
-        _type = cfg.datasets[i]["type"]
-        ds_type_fn = locals()[_type]
-        train_datasets[i] = data_set.map(
-            ds_type_fn,
-            desc="Mapping RL Dataset",
-        )
-    train_dataset = concatenate_datasets(train_datasets)
-
-    # eval_dataset = eval_dataset.map(intel_apply_chatml)
-
+    train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
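
The inline ChatML mappers deleted above now live behind a single entry point. A minimal sketch of what a load_prepare_dpo_datasets-style helper can look like, assuming hub-hosted datasets; the real implementation in src/axolotl/utils/data.py is not shown in this diff and also covers local dataset files and accelerator-coordinated loading:

# Hypothetical sketch; the shipped load_prepare_dpo_datasets may differ.
from datasets import concatenate_datasets, load_dataset

from axolotl.prompt_strategies.dpo import load as load_dpo_strategy


def load_prepare_dpo_datasets(cfg):
    def load_split(dataset_cfgs):
        split_datasets = []
        for ds_cfg in dataset_cfgs:
            # assumption: HF hub datasets only; local paths are handled upstream
            split_datasets.append(load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
        for i, ds_cfg in enumerate(dataset_cfgs):
            # e.g. type "chatml.argilla" resolves via axolotl.prompt_strategies.dpo.load
            strategy = load_dpo_strategy(ds_cfg["type"], cfg)
            if strategy:
                split_datasets[i] = split_datasets[i].map(
                    strategy, desc="Mapping RL Dataset"
                )
        return concatenate_datasets(split_datasets)

    train_dataset = load_split(cfg.datasets)
    eval_dataset = load_split(cfg.test_datasets) if cfg.test_datasets else None
    return train_dataset, eval_dataset

The step arithmetic that follows is unchanged: with 10,000 preference pairs, 3 epochs, and a batch size of 32, total_num_steps = ceil(10000 * 3 / 32) = 938.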
src/axolotl/core/trainer_builder.py (86 additions, 23 deletions)
@@ -12,14 +12,19 @@
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import List, Optional, Type, Union

 import torch
 import transformers
 from datasets import Dataset
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
-from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
+from transformers import (
+    EarlyStoppingCallback,
+    Trainer,
+    TrainerCallback,
+    TrainingArguments,
+)
 from transformers.trainer_utils import seed_worker
 from trl import DPOTrainer

@@ -460,6 +465,7 @@ class TrainerBuilderBase(abc.ABC):
     _train_dataset = None
     _eval_dataset = None
     _model_ref = None
+    _peft_config = None

     def __init__(self, cfg, model, tokenizer):
         self.cfg = cfg
@@ -490,26 +496,33 @@ def eval_dataset(self):
     def eval_dataset(self, dataset):
         self._eval_dataset = dataset

+    @property
+    def peft_config(self):
+        return self._peft_config
+
+    @peft_config.setter
+    def peft_config(self, peft_config):
+        self._peft_config = peft_config
+
     @abstractmethod
     def build(self, total_num_steps):
         pass

-    @abstractmethod
-    def get_callbacks(self):
-        pass
+    def get_callbacks(self) -> List[TrainerCallback]:
+        callbacks = []
+        if self.cfg.use_wandb:
+            callbacks.append(
+                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
+            )
+
+        return callbacks

     @abstractmethod
     def get_post_trainer_create_callbacks(self, trainer):
         """
         Callbacks added after the trainer is created, usually b/c these need access to the trainer
         """

-
-class HFCausalTrainerBuilder(TrainerBuilderBase):
-    """
-    Build the HuggingFace training args/trainer for Causal models
-    """
-
     def hook_pre_create_training_args(self, training_arguments_kwargs):
         # TODO
         return training_arguments_kwargs
@@ -526,10 +539,16 @@ def hook_post_create_trainer(self, trainer):
         # TODO
         return trainer

+
+class HFCausalTrainerBuilder(TrainerBuilderBase):
+    """
+    Build the HuggingFace training args/trainer for Causal models
+    """
+
     def get_callbacks(self):
-        callbacks = []
+        callbacks = super().get_callbacks()
         callbacks.append(GPUStatsCallback(self.cfg))
-        callbacks.append(EvalFirstStepCallback)
+        callbacks.append(EvalFirstStepCallback())

         if self.cfg.relora_steps:
             callbacks.append(ReLoRACallback(self.cfg))
@@ -538,7 +557,7 @@ def get_callbacks(self):
             hasattr(self.model, "use_bettertransformer")
             and self.model.use_bettertransformer is True
         ):
-            callbacks.append(SaveBetterTransformerModelCallback)
+            callbacks.append(SaveBetterTransformerModelCallback())

         if self.cfg.use_wandb:
             callbacks.append(
@@ -931,7 +950,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
     """

     def get_callbacks(self):
-        callbacks = []
+        callbacks = super().get_callbacks()
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
@@ -949,21 +968,60 @@ def build_training_arguments(self, total_num_steps):
         ]:
             if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                 training_args_kwargs[arg] = getattr(self.cfg, arg)
+
+        if self.cfg.hub_model_id:
+            training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
+            training_args_kwargs["push_to_hub"] = True
+            training_args_kwargs["hub_private_repo"] = True
+            training_args_kwargs["hub_always_push"] = True
+
+            if self.cfg.hub_strategy:
+                training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
+
+        if self.cfg.save_safetensors is not None:
+            training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
+
+        if self.eval_dataset:
+            training_args_kwargs["evaluation_strategy"] = "steps"
+            training_args_kwargs["eval_steps"] = self.cfg.eval_steps
+        else:
+            training_args_kwargs["evaluation_strategy"] = "no"
+        if self.cfg.bf16 or self.cfg.bfloat16:
+            training_args_kwargs["bf16"] = True
+
+        training_args_kwargs["lr_scheduler_type"] = (
+            self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
+        )
+        training_args_kwargs["lr_scheduler_kwargs"] = (
+            self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
+        )
+
+        if self.cfg.dataloader_pin_memory is not None:
+            training_args_kwargs[
+                "dataloader_pin_memory"
+            ] = self.cfg.dataloader_pin_memory
+        if self.cfg.dataloader_num_workers is not None:
+            training_args_kwargs[
+                "dataloader_num_workers"
+            ] = self.cfg.dataloader_num_workers
+        if self.cfg.dataloader_prefetch_factor is not None:
+            training_args_kwargs[
+                "dataloader_prefetch_factor"
+            ] = self.cfg.dataloader_prefetch_factor
+
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
-            max_steps=total_num_steps,
+            max_steps=self.cfg.max_steps or total_num_steps,
             remove_unused_columns=False,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            evaluation_strategy="no",
-            # eval_steps=self.cfg.eval_steps,
             save_strategy="steps",
             save_steps=self.cfg.save_steps,
             output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
-            bf16=True,
             gradient_checkpointing=self.cfg.gradient_checkpointing,
-            gradient_checkpointing_kwargs={"use_reentrant": False},
+            gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
+            or {"use_reentrant": False},
             logging_first_step=True,
             logging_steps=1,
             optim=self.cfg.optimizer,
@@ -982,22 +1040,27 @@ def build(self, total_num_steps):
             dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         elif self.cfg.rl == "kto_pair":
             dpo_trainer_kwargs["loss_type"] = "kto_pair"
+
+        if self.eval_dataset:
+            dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
+        if self.cfg.adapter and self.peft_config:
+            dpo_trainer_kwargs["peft_config"] = self.peft_config
         dpo_trainer = DPOTrainer(
             self.model,
             self.model_ref,
             args=training_args,
             beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
-            # eval_dataset=self.eval_dataset,
-            eval_dataset=None,
             tokenizer=self.tokenizer,
             max_length=self.cfg.sequence_len,
             max_target_length=None,
             max_prompt_length=self.cfg.sequence_len,
-            generate_during_eval=True,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
+        dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
+        for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
+            dpo_trainer.add_callback(callback)

         return dpo_trainer
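
With get_callbacks hoisted into TrainerBuilderBase, both the causal and DPO builders attach SaveAxolotlConfigtoWandBCallback whenever use_wandb is set (the "support trainer callback to push config to wandb" commit). A hedged sketch of what such a callback can look like; the class actually shipped in axolotl may differ in its details:

# Hypothetical sketch of a config-to-wandb TrainerCallback; not the verbatim
# axolotl implementation.
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)


class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
    """Upload the axolotl config to the active wandb run at train start."""

    def __init__(self, axolotl_config_path):
        self.axolotl_config_path = axolotl_config_path

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.is_world_process_zero:
            import wandb  # lazy import keeps wandb an optional dependency

            # store the training recipe alongside the run for reproducibility
            wandb.save(self.axolotl_config_path, policy="now")
        return control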
src/axolotl/prompt_strategies/dpo/__init__.py (new file, 21 additions)
@@ -0,0 +1,21 @@
"""
module for DPO style dataset transform strategies
"""

import importlib
import logging

LOG = logging.getLogger("axolotl")


def load(strategy, cfg):
try:
load_fn = strategy.split(".")[-1]
Review thread on the line above, load_fn = strategy.split(".")[-1]:

filippo82 (Contributor) commented:
This is most likely not correct. The strategy includes underscores, not ".", such as intel_apply_chatml.

filippo82 (Contributor) suggested, Jan 23, 2024:
def load(strategy, cfg):
    try:
        load_fn = strategy.split("_")[-1]
        # strategy = ".".join(strategy.split("_")[:-1])
        LOG.info(load_fn)
        LOG.info(strategy)
        mod = importlib.import_module(f".{load_fn}", "axolotl.prompt_strategies.dpo")
        func = getattr(mod, strategy)
        load_kwargs = {}
        return func(cfg, **load_kwargs)
    except Exception as e:  # pylint: disable=broad-exception-caught
        LOG.warning(e)
        return None

filippo82 (Contributor) followed up:
This works for me.

winglian (Collaborator, author) replied:
The intention is that the setting is something like

type: chatml.argilla

in which case it will load the argilla function from the axolotl.prompt_strategies.dpo.chatml module.

filippo82 (Contributor) replied:
Hi @winglian 👋🏻 thanks. That makes sense. I will test it later today 👍🏻

strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
load_kwargs = {}
return func(cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None
winglian marked this conversation as resolved.
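
Per winglian's explanation, a dataset type of chatml.argilla resolves to an argilla function inside the axolotl.prompt_strategies.dpo.chatml module. A hedged sketch of that function, reconstructed from the argilla_apply_chatml transform deleted from cli/__init__.py; the chatml module shipped in this PR may differ:

# Hypothetical axolotl/prompt_strategies/dpo/chatml.py, based on the removed
# argilla_apply_chatml mapper; not necessarily the shipped code.
def argilla(cfg):  # pylint: disable=unused-argument
    def transform_fn(sample):
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample["prompt"] = (
                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
            )
        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
        return sample

    return transform_fn

Tracing load("chatml.argilla", cfg) through the new module: load_fn becomes "argilla", strategy becomes "chatml", importlib imports axolotl.prompt_strategies.dpo.chatml, and argilla(cfg) returns the per-sample mapping function that the DPO dataset loading applies via Dataset.map.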